# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import types
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import torch
import torch.nn.functional as F
from ..model_parallel_config import ModelParallelConfig
from ..utils import init_method_normal, scaled_init_method_normal
@dataclass
class TransformerConfig(ModelParallelConfig):
"""Configuration object for megatron-core transformers.
The initialization function has an argument for each parameter, including those in ModelParallelConfig.
"""
####################
# model architecture
####################
num_layers: int = 0
"""Number of transformer layers in a transformer block."""
hidden_size: int = 0
"""Transformer hidden size."""
num_attention_heads: int = 0
"""Number of transformer attention heads."""
num_query_groups: int = None
"""Number of query groups for group query attention. If None, normal attention is used."""
ffn_hidden_size: int = None
"""Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size if not provided."""
kv_channels: int = None
"""Projection weights dimension in multi-head attention. This is set to hidden_size //
num_attention_heads if not provided."""
hidden_dropout: float = 0.1
"""Dropout probability for transformer hidden state."""
attention_dropout: float = 0.1
"""Post attention dropout probability."""
fp32_residual_connection: bool = False
"""If true, move residual connections to fp32."""
# @jcasper should we keep this option?
apply_residual_connection_post_layernorm: bool = False
"""If True, uses the original BERT residule connection ordering."""
layernorm_epsilon: float = 1e-5
"""Epsilon value for any LayerNorm operations."""
layernorm_zero_centered_gamma: bool = False
"""If set to True, the LayerNorm is adjusted to center the gamma values around 0. This improves
numerical stability."""
add_bias_linear: bool = True
"""Include a bias term in all linear layers (QKV projections, after core attention, and two in
MLP layer)."""
add_qkv_bias: bool = False
"""Add a bias term only for QKV projections."""
gated_linear_unit: bool = False
"""Use a gated linear unit for the first linear layer in the MLP."""
activation_func: Callable = F.gelu
"""Activation function to use for the non-linearity in the MLP."""
activation_func_fp8_input_store: bool = False
"""Store the input of MLP activation function in FP8 for backprop to save memory.
The stored input is cast back to the original precision before the backward computation."""
num_moe_experts: int = None
"""Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None
for no MoE."""
rotary_interleaved: bool = False
"""True is rotate pairs of even and odd dimensions (RoFormer style), False is rotate pairs of
first half and second half (LLaMa style). Default to False."""
window_size: Optional[Tuple[int, int]] = None
"""If not None, then will use sliding window attention. The size of the window is specified by
the numbers inside the tuple; -1 is a special value meaning "infinite window size"."""
#normalization: bool = "LayerNorm"
normalization: bool = "RMSNorm"
"""Which norm to use for normalization layers, valid options are `LayerNorm` and `RMSNorm`."""
qk_layernorm: bool = False
"""Whether to apply LayerNorm to the query and key embeddings."""
test_mode: bool = False
"""Whether to run real-time tests."""
calculate_per_token_loss: bool = False
"""Whether cross entropy loss is calculated over the actual number of non-padded tokens in the
global batch, versus the default behavior of assuming all tokens are non-padded."""
####################
# initialization
####################
init_method: Callable = None
"""Method to initialize weights. Note that bias is always set to zero. Should be a function that
takes a single Tensor and initializes it. If None, will be set to
megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with
mean=0.0 and std=init_method_std."""
output_layer_init_method: Callable = None
"""Method to initialize weights of the output layer of both attention and MLP blocks. If None,
will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn
init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers)."""
init_method_std: float = 0.02
"""Standard deviation of the zero mean normal for the default initialization method, not used if
init_method and output_layer_init_method are provided."""
####################
# mixed-precision
####################
apply_query_key_layer_scaling: bool = False
"""If true, scale Q * K^T by 1 / layer-number. This improve numeric stability when training with
fp16."""
attention_softmax_in_fp32: bool = True
"""If True, run attention masking and softmax in fp32. This should be True if
apply_query_key_layer_scaling is True."""
####################
# fusion
####################
bias_activation_fusion: bool = False
"""If True, fuses bias addition and the activation function when possible."""
masked_softmax_fusion: bool = False
"""If True, uses softmax fusion."""
persist_layer_norm: bool = False
"""If True, uses the persistent fused layer norm kernel. This kernel only supports a fixed set
of hidden sizes."""
memory_efficient_layer_norm: bool = False
"""If True, and using local layers (not from TransformerEngine), tells Apex to use the memory
efficient fused LayerNorm kernel. Ignored if not using LayerNorm."""
bias_dropout_fusion: bool = False # TODO: this should be bias_dropout_add_fusion?
"""If True, uses bias dropout fusion."""
apply_rope_fusion: bool = False
"""If True, use fused RoPE kernel."""
####################
# activation recomputation
####################
recompute_granularity: str = None
"""Determines which type of activation recompute to use. Megatron-core supports 'selective'
activation checkpointing where only the memory intensive part of attention is checkpointed.
These memory intensive activations are also less compute intensive which makes activation
checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large
Transformer Models (https://arxiv.org/abs/2205.05198) for more details. 'full' will checkpoint
the entire transformer layer. If None, no recompute is performed and all activations are saved.
If set, must be 'selective' or 'full'. 'selective' always uses all layers.
"""
recompute_method: str = None
"""Determines which transformer layers will be recomputed. uniform will uniformly divide the
total number of transformer layers in a transformer block and recompute the input activation of
each divided chunk at the specified granularity. block will recompute the input activations for
only a set number of transformer layers per pipeline stage. The rest of the layers in the
pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all
layers will do recomputation. If set, must be 'uniform' or 'block'."""
recompute_num_layers: int = None
"""When recompute_method is uniform, recompute_num_layers is the number of transformer layers in
each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is
the number of transformer layers to recompute within each pipeline stage. Must be None for
'selective' activation checkpointing."""
distribute_saved_activations: bool = None
"""If True, distribute recomputed activations across the model parallel group."""
####################
# fp8 related
####################
fp8: str = None
"""If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined
choices (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8
activation and weight tensors and e5m2 for all FP8 output activation gradient tensors."""
fp8_margin: int = 0
"""Margin for the scaling factor computation."""
fp8_interval: int = 1
"""Controls how often the scaling factor is recomputed."""
fp8_amax_history_len: int = 1
"""The length of the amax history window used for scaling factor computation."""
fp8_amax_compute_algo: str = "most_recent"
"""Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
always chooses the most recently seen value.
"""
fp8_wgrad: bool = True
"""When set to False, override FP8 config options and do the wgrad computation in higher precision."""
fp8_dot_product_attention: bool = False
"""When set to True, use the FP8 implementation of Dot Product Attention."""
fp8_multi_head_attention: bool = False
"""When set to True, use the FP8 implementation of Multi Head Attention."""
####################
# MoE related
####################
moe_router_load_balancing_type: str = "aux_loss"
"""Determines the load balancing strategy for the router. "aux_loss" corresponds to the load
balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing
algorithm used in S-BASE, and "none" implies no load balancing."""
moe_router_topk: int = 2
"""Number of experts to route to for each token."""
moe_grouped_gemm: bool = False
"""When there are multiple experts per rank, compress multiple local (potentially small) gemms
in a single kernel launch to improve the utilization and performance by leveraging the Grouped
GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
"""
moe_aux_loss_coeff: float = 0 # 1e-2 would be a good start value for load balance loss.
"""Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended."""
moe_z_loss_coeff: float = None # 1e-3 would be a good start value for z-loss
"""Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended."""
moe_input_jitter_eps: float = None
"""Add noise to the input tensor by applying jitter with a specified epsilon value."""
moe_token_dropping: bool = False # TODO: Support token dropping.
"""This feature involves selectively dropping and padding tokens for each expert to achieve a
specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is
currently unsupported so should remain False."""
moe_token_dispatcher_type: str = "allgather"
"""The type of token dispatcher to use. The default is 'allgather'. Options are 'allgather' and 'alltoall'."""
moe_per_layer_logging: bool = False
"""Enable per-layer logging for MoE, currently supports auxiliary loss and z loss."""
moe_expert_capacity_factor: float = None
"""moe_expert_capacity_factor (float): The capacity factor for each expert, None means no token will be dropped. The default is None."""
moe_pad_expert_input_to_capacity: bool = False
"""moe_pad_expert_input_to_capacity (bool): If True, pads the input for each expert to match the expert capacity length, effective only after the moe_expert_capacity_factor is set. The default setting is False."""
moe_token_drop_policy: str = 'probs'
"""The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped.
"""
moe_layer_recompute: bool = False
"""Memory optimization: checkpointing moe_layer to save actiavtion memory."""
####################
# miscellaneous
####################
clone_scatter_output_in_embedding: bool = True
"""When set to True, clone the output of scatter_to_sequence_parallel_region in embedding layer
to facilitate garbage collection of input."""
disable_parameter_transpose_cache: bool = False
"""When set to true, the parameter transposes are not cached for subsequent iterations."""
enable_cuda_graph: bool = False
"""When set to true, TransformerLayer blocks are wrapped with CUDA graph."""
def __post_init__(self):
""" Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
"""
super().__post_init__()
if self.fp16 and self.bf16:
raise ValueError(
f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.'
)
if self.num_attention_heads % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
if self.ffn_hidden_size is None:
self.ffn_hidden_size = 4 * self.hidden_size
if self.kv_channels is None:
self.kv_channels = self.hidden_size // self.num_attention_heads
if self.num_query_groups is None:
self.num_query_groups = self.num_attention_heads
if self.num_query_groups % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_query_groups ({self.num_query_groups}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
if self.expert_model_parallel_size > 1 and self.num_moe_experts is None:
raise ValueError('num_moe_experts must be set (not None) to use expert model parallelism.')
if self.num_moe_experts is not None and self.num_moe_experts <= 0:
raise ValueError('num_moe_experts must be a positive integer.')
if self.moe_expert_capacity_factor is not None:
if self.moe_token_dispatcher_type != "alltoall":
raise ValueError(
f'moe_expert_capacity_factor only works with alltoall token dispatcher'
)
if self.moe_expert_capacity_factor < 0:
self.moe_expert_capacity_factor = None
if self.moe_router_load_balancing_type not in ["aux_loss", "none"]:
raise ValueError(
f'moe_expert_capacity_factor only works with aux_loss or none load balancing'
)
if self.moe_pad_expert_input_to_capacity:
if self.moe_expert_capacity_factor is None:
raise ValueError(
f'moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity'
)
if self.cpu_offloading and (
self.cpu_offloading_num_layers < 0 or self.cpu_offloading_num_layers >= self.num_layers
):
raise ValueError(
f'CPU offloading can be done only for layers less than {self.num_layers}'
)
if self.cpu_offloading and self.pipeline_model_parallel_size > 1:
raise ValueError(
f'Currently there is no support for Pipeline parallelism with CPU offloading'
)
if self.cpu_offloading and self.recompute_granularity is not None:
raise ValueError(
f'CPU offloading does not work when activation recomputation is enabled'
)
if self.recompute_granularity is not None:
if not self.recompute_granularity in ['full', 'selective']:
raise ValueError(
f'recompute_granularity: {self.recompute_granularity} must be "full" or "selective".'
)
if self.recompute_method is not None:
if not self.recompute_method in ['block', 'uniform']:
raise ValueError(
f'recompute_method: {self.recompute_method} must be "block" or "uniform".'
)
elif self.recompute_granularity != 'selective':
raise ValueError(
f'Using recompute_granularity: {self.recompute_granularity} so recompute_method must be "block" or "uniform"'
)
if self.recompute_granularity != 'selective' and self.recompute_num_layers is None:
raise ValueError(
f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be between '
f'1 and num_layers_per_pipeline_rank: {self.num_layers // self.pipeline_model_parallel_size}'
)
elif (
self.recompute_granularity == 'selective' and self.recompute_num_layers is not None
):
raise ValueError(
f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be None.'
)
if self.distribute_saved_activations and self.sequence_parallel:
raise ValueError(
f'distribute_saved_activations: {self.distribute_saved_activations} must be false when sequence parallel is enabled: {self.sequence_parallel}'
)
if self.virtual_pipeline_model_parallel_size is not None:
if not self.num_layers % self.virtual_pipeline_model_parallel_size == 0:
raise ValueError(
f'num_layers: {self.num_layers} must be divisible by virtual_pipeline_model_parallel_size {self.virtual_pipeline_model_parallel_size}'
)
if self.bias_activation_fusion:
if self.activation_func not in [F.gelu, F.silu]:
raise ValueError(
"When bias_activation_fusion is True, activation function should be either gelu or swiglu"
)
if (
self.activation_func == F.gelu
and not self.gated_linear_unit
and not self.add_bias_linear
):
raise ValueError(
"When bias_activation_fusion is True, gated_linear_unit is False, "
"and activation function is gelu, add_bias_linear must also be True."
)
if self.activation_func_fp8_input_store:
if self.activation_func != F.silu or not self.gated_linear_unit:
raise ValueError("Storing activation input in FP8 is supported only for SwiGLU.")
if self.apply_rope_fusion and self.rotary_interleaved:
raise ValueError(f'rotary_interleaved does not work with apply_rope_fusion.')
if self.init_method is None:
self.init_method = init_method_normal(self.init_method_std)
if self.output_layer_init_method is None:
self.output_layer_init_method = scaled_init_method_normal(
self.init_method_std, self.num_layers
)
if self.moe_extended_tp:
if self.moe_token_dispatcher_type != 'allgather':
raise ValueError(
"Moe extended TP parallelism only applies to allgather based token dispatcher."
)
extended_tp_size = self.tensor_model_parallel_size * self.expert_model_parallel_size
if self.ffn_hidden_size % extended_tp_size != 0:
raise ValueError(
f'ffn_hidden_size: {self.ffn_hidden_size} must be divisible by extended_tp_size {extended_tp_size}'
)
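# Usage sketch (illustrative, not part of the original module): constructing a
# TransformerConfig with only the required architecture fields and letting
# __post_init__ derive the rest. Assumes the default ModelParallelConfig settings.
def _example_transformer_config():
    config = TransformerConfig(num_layers=12, hidden_size=768, num_attention_heads=12)
    # Derived in __post_init__ when left as None:
    assert config.ffn_hidden_size == 4 * 768     # 4 * hidden_size
    assert config.kv_channels == 768 // 12       # hidden_size // num_attention_heads
    assert config.num_query_groups == 12         # defaults to num_attention_heads
    return config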
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from abc import ABC
from dataclasses import dataclass, field
from typing import Dict, Optional, Union
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import apply_prefix_mapping
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor
@dataclass
class TransformerLayerSubmodules:
input_layernorm: Union[ModuleSpec, type] = IdentityOp
self_attention: Union[ModuleSpec, type] = IdentityOp
self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
cross_attention: Union[ModuleSpec, type] = IdentityOp
cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
mlp: Union[ModuleSpec, type] = IdentityOp
mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp
# Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
class BaseTransformerLayer(ABC):
""" A common parent class for `TransformerLayer` like implementations.
A dummy class that is subclassed by similar `TransformerLayer`s e.g. the
`TransformerLayer` in this file and possibly other `TransformerLayer`
implementations that aim to use `TransformerBlock` as the base module.
The main purpose is to check if any layer (or module) provided in the spec
is a subclass of this class to allow fanning-out of that spec for all the
layers in the `TransformerBlock`. See `_get_block_submodules` method
implementation in `transformer_block.py` file for more details.
"""
def __init__(self):
pass
class TransformerLayer(MegatronModule, BaseTransformerLayer):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self,
config: TransformerConfig,
submodules: TransformerLayerSubmodules,
layer_number: int = 1,
hidden_dropout: float = None,
):
super().__init__(config=config)
self.submodules_config = submodules
self.layer_number = layer_number + self._get_layer_offset()
self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout
## [Module 1: Input Layernorm] Optional Layernorm on the input data
# TODO: add pytorch only layernorm
self.input_layernorm = build_module(
submodules.input_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
## [Module 2: SelfAttention]
self.self_attention = build_module(
submodules.self_attention, config=self.config, layer_number=layer_number,
)
## [Module 3: BiasDropoutFusion]
self.self_attn_bda = build_module(submodules.self_attn_bda)
## [Module 4: Post SelfAttention] Optional Layernorm after self-attn
self.pre_cross_attn_layernorm = build_module(
submodules.pre_cross_attn_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
## [Module 5: CrossAttention]
self.cross_attention = build_module(
submodules.cross_attention, config=self.config, layer_number=layer_number,
)
## [Module 6: BiasDropoutFusion]
self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config,)
## [Module 7: Pre MLP] Optional Layernorm before MLP
self.pre_mlp_layernorm = build_module(
submodules.pre_mlp_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
## [Module 8: MLP block]
# TODO how to set the gpt_layer_spec.py when we have moe_frequency > 1,
# where MLP and MoE layer both appear alternately?
self.mlp = build_module(submodules.mlp, config=self.config)
if hasattr(self.mlp, 'set_layer_number'):
self.mlp.set_layer_number(self.layer_number)
## [Module 9: BiasDropoutFusion]
self.mlp_bda = build_module(submodules.mlp_bda)
# @jcasper how should we handle nvfuser?
# Set bias+dropout+add fusion grad_enable execution handler.
# TORCH_MAJOR = int(torch.__version__.split('.')[0])
# TORCH_MINOR = int(torch.__version__.split('.')[1])
# use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
# self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
self.bias_dropout_add_exec_handler = torch.enable_grad
def _get_layer_offset(self):
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
num_layers_per_pipeline_rank = (
self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
total_num_layers = self.config.num_layers
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
total_virtual_chunks = total_num_layers // vp_size
offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)
else:
# Each stage gets a contiguous set of layers.
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
offset = pipeline_rank * num_layers_per_pipeline_rank
else:
offset = 0
return offset
def forward(
self,
hidden_states,
attention_mask,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
):
# hidden_states: [s, b, h]
# Residual connection.
residual = hidden_states
# Optional Input Layer norm
input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
# Residual connection.
residual = hidden_states
# Optional Layer norm after self-attention
pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)
# Cross attention.
attention_output_with_bias = self.cross_attention(
pre_cross_attn_layernorm_output,
attention_mask=context_mask,
key_value_states=context,
inference_params=inference_params,
)
if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
context = attention_output_with_bias["context"]
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
# Residual connection.
residual = hidden_states
# Optional Layer norm post the cross-attention.
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
# MLP.
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
return output, context
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
prefixed_map = {
f'{prefix}{k}': f'{prefix}{v}'
for k, v in self.submodules_config.sharded_state_dict_keys_map.items()
}
if prefixed_map:
apply_prefix_mapping(sharded_state_dict, prefixed_map)
return sharded_state_dict
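# Illustrative sketch (not part of Megatron-Core): the layer-offset arithmetic from
# TransformerLayer._get_layer_offset above, with the parallel ranks passed in explicitly
# instead of being read from parallel_state. The helper name and signature are assumptions
# made for this example only.
def _example_layer_offset(
    num_layers: int,
    pipeline_rank: int,
    pipeline_world_size: int,
    vp_rank: Optional[int] = None,
    vp_size: Optional[int] = None,
) -> int:
    num_layers_per_pipeline_rank = num_layers // pipeline_world_size
    if vp_size is not None:
        # With virtual (interleaved) pipelining each pipeline rank owns several
        # non-contiguous chunks, so the offset combines the virtual-chunk stride with
        # this pipeline rank's position inside each chunk.
        num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
        total_virtual_chunks = num_layers // vp_size
        return vp_rank * total_virtual_chunks + pipeline_rank * num_layers_per_virtual_rank
    # Without virtual pipelining each stage owns one contiguous block of layers.
    return pipeline_rank * num_layers_per_pipeline_rank if pipeline_world_size > 1 else 0
# For example, 16 layers over 4 pipeline stages with no virtual pipelining gives
# _example_layer_offset(16, 2, 4) == 8, i.e. stage 2 starts at layer 8.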
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for transformer layers."""
from functools import lru_cache
from operator import itemgetter
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple, Union
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedStateDict, StateDict
from megatron.core.jit import jit_fuser
from megatron.core.utils import (
make_sharded_tensor_for_checkpoint,
make_tp_sharded_tensor_for_checkpoint,
)
def get_linear_layer(rows, columns, init_method, perform_initialization=True):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
if perform_initialization: # Take from modelparallel config
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
@lru_cache(maxsize=32)
def get_default_causal_mask(sq: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input."""
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
@jit_fuser
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def openai_gelu(x):
return gelu_impl(x)
# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@jit_fuser
def erf_gelu(x):
return (
x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
)
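# Sanity sketch (illustrative only): gelu_impl above is the tanh approximation of GeLU and
# erf_gelu is the exact erf-based form. Assuming PyTorch >= 1.12, both can be compared with
# the corresponding torch.nn.functional.gelu variants.
def _example_gelu_variants():
    x = torch.randn(8)
    assert torch.allclose(gelu_impl(x), torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-4)
    assert torch.allclose(erf_gelu(x), torch.nn.functional.gelu(x), atol=1e-4)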
def make_sharded_tensors_for_checkpoint(
state_dict: StateDict,
prefix: str,
tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None,
sharded_offsets: Iterable[Tuple[int, int, int]] = (),
extra_state_suffix: str = '_extra_state',
):
"""Wraps tensors from transformer layers with ShardedTensor or ShardedObject.
For a given `state_dict`, wraps:
- all _extra_states with ShardedObject
- all tensors specified in tensor_parallel_layers_axis_map with TP and DP sharded ShardedTensor
- other values with DP sharded ShardedTensor
Args:
state_dict (StateDict): state_dict to convert
prefix (str): prefix appended to keys in final state dict
tensor_parallel_layers_axis_map (Dict[str, int], optional): dict mapping layer
names to the axis for TP sharding
sharded_offsets (Iterable[Tuple[int, int, int]], optional): sharding already
applied (e.g. PP related), passed along to ShardedTensor
extra_state_suffix (str, default = '_extra_state'): layers with this
suffix will be wrapped with ShardedObject instead of ShardedTensor.
"""
if tensor_parallel_layers_axis_map is None:
tensor_parallel_layers_axis_map = {}
sharded_state_dict = {}
for layer_name in state_dict.keys():
tensor = state_dict[layer_name]
layer_key = f'{prefix}{layer_name}'
if layer_name.endswith(extra_state_suffix):
sharded_state_dict[layer_key] = make_sharded_object_for_checkpoint(
tensor, layer_key, sharded_offsets
)
elif layer_name in tensor_parallel_layers_axis_map:
tp_axis = tensor_parallel_layers_axis_map[layer_name]
sharded_state_dict[layer_key] = make_tp_sharded_tensor_for_checkpoint(
tensor, layer_key, tp_axis, prepend_offsets=sharded_offsets,
)
else:
sharded_state_dict[layer_key] = make_sharded_tensor_for_checkpoint(
tensor, layer_key, prepend_offsets=sharded_offsets,
)
return sharded_state_dict
def make_sharded_object_for_checkpoint(
obj: Any,
key: str,
sharded_offsets: Iterable[Tuple[int, int, int]] = (),
replica_id: Union[None, int, Tuple[int, ...]] = None,
**kwargs,
):
""" Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group).
Args:
obj (object): any object to be sharded
key (str): unique identifier of the object
sharded_offsets (Iterable[Tuple[int, int, int]]): offsets normally
prepended to ShardedTensors, will be used as global offsets for
ShardedObject
replica_id (Union[None, int, Tuple[int, ...]]): replica id
"""
if replica_id is None:
replica_id = (
0,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
return ShardedObject(key, obj, *_get_extra_state_offsets(sharded_offsets), replica_id, **kwargs)
def _get_extra_state_offsets(
sharded_offsets: Iterable[Tuple[int, int, int]]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
""" Turns ShardedTensor offsets into offsets suitable for ShardedObject. """
if sharded_offsets:
sharded_offsets = sorted(sharded_offsets, key=itemgetter(0)) # sort by axis
axis, extra_state_offset, extra_state_shape = zip(*sharded_offsets)
assert list(axis) == list(
range(len(axis))
), f'Expected contiguous axis for offsets: {sharded_offsets}'
else:
extra_state_shape = (1,)
extra_state_offset = (0,)
return extra_state_shape, extra_state_offset
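# Small illustrative check of _get_extra_state_offsets: with no prior sharding the extra
# state is treated as a single (1,)-shaped object at offset (0,); with prepended offsets the
# per-axis shapes and offsets are unzipped in axis order.
def _example_extra_state_offsets():
    assert _get_extra_state_offsets(()) == ((1,), (0,))
    # One prepended axis (axis=0, offset=2, shape=4) -> shape (4,), offset (2,).
    assert _get_extra_state_offsets([(0, 2, 4)]) == ((4,), (2,))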
def sharded_state_dict_default(
module: torch.nn.Module,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) -> ShardedStateDict:
"""Provides implementation for sharded_state_dict method for non-MegatronModules.
Tries to call `module.sharded_state_dict` when possible,
otherwise uses regular state dict and assumes tensors are replicated across TP and DP.
`keep_vars=True` is passed to module.state_dict so that optimizer states
can be sharded later on.
Args:
module (torch.nn.Module): module which sharded state dict we want to obtain
prefix (str): prefix for the state dict keys
sharded_offsets (Tuple[Tuple[int, int, int]], optional): sharding already
applied (e.g. PP related) by super-modules. Passed along to ShardedTensor
metadata (dict, optional): metadata passed to module sharded_state_dict method
Returns:
dict: dictionary of state dict keys mapped to ShardedTensors
"""
if hasattr(module, 'sharded_state_dict'):
module_sharded_sd = module.sharded_state_dict(
prefix=prefix, sharded_offsets=sharded_offsets, metadata=metadata
)
else:
module_sd = module.state_dict(prefix='', keep_vars=True)
module_sharded_sd = make_sharded_tensors_for_checkpoint(
module_sd, prefix, {}, sharded_offsets,
)
return module_sharded_sd
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Utility functions used throughout Megatron core"""
import array
import hashlib
import logging
import math
import operator
import queue
import socket
import sys
import threading
import time
import traceback
from dataclasses import dataclass
from datetime import datetime
from functools import reduce
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedTensor
logger = logging.getLogger(__name__)
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def get_attr_wrapped_model(model, attr, allow_none=True, return_model_obj=False):
"""Get an attribute from a wrapped model.
If return_model_obj is true, return the object that has the 'attr' attribute;
otherwise, return the attribute directly."""
if isinstance(model, list):
raise RuntimeError("_get_attr_wrapped_model given a list of models")
if allow_none:
def condition(model, attr):
return not hasattr(model, attr)
else:
def condition(model, attr):
return getattr(model, attr, None) is None
while condition(model, attr):
if not hasattr(model, "module"):
raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
model = model.module
if return_model_obj:
return model
return getattr(model, attr)
def get_model_type(model):
return get_attr_wrapped_model(model, 'model_type')
def get_model_config(model):
return get_attr_wrapped_model(model, 'config', allow_none=False)
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if (
self.buffer.get((name, dtype), None) is None
or self.buffer[(name, dtype)].numel() < required_len
):
self.buffer[(name, dtype)] = torch.empty(
required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
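# Usage sketch (illustrative; assumes a CUDA device is available): the buffer hands out
# views over one flat allocation per (name, dtype), so two requests under the same name
# share storage and must not be live at the same time.
def _example_global_memory_buffer():
    buf = GlobalMemoryBuffer()
    a = buf.get_tensor([4, 8], torch.float16, "mpu")  # allocates 32 elements
    b = buf.get_tensor([2, 8], torch.float16, "mpu")  # reuses the same storage
    assert a.data_ptr() == b.data_ptr()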
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-effect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad,)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
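# Small illustrative check: a sliced tensor is a view (its ._base is set); passing it
# through make_viewless_tensor returns a tensor backed by the same data but without the
# ._base reference, so assert_viewless_tensor below accepts it.
def _example_make_viewless():
    base = torch.randn(4, 4, requires_grad=True)
    view = base[:2]  # a true view; view._base is base
    assert view._base is not None
    out = make_viewless_tensor(view, requires_grad=True, keep_graph=True)
    assert out._base is None and out.data_ptr() == view.data_ptr()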
def assert_viewless_tensor(tensor, extra_msg=None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[assert_viewless_tensor(t) for t in tensor]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(
tensor,
extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
% ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape),
)
tensor.data = new_data_tensor
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
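# Usage sketch (illustrative): both factories return closures that initialize a tensor in
# place; scaled_init_method_normal shrinks the std by sqrt(2 * num_layers), matching the
# default output_layer_init_method described in TransformerConfig.
def _example_init_methods():
    w = torch.empty(16, 16)
    init_method_normal(0.02)(w)                         # N(0, 0.02)
    scaled_init_method_normal(0.02, num_layers=12)(w)   # N(0, 0.02 / sqrt(24))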
def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any):
"""If torch distributed is initialized, log only on rank
Args:
logger (logging.Logger): The logger to write the logs
args (Tuple[Any]): All logging.Logger.log positional arguments
rank (int, optional): The rank to write on. Defaults to 0.
kwargs (Dict[str, Any]): All logging.Logger.log keyword arguments
"""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == rank:
logger.log(*args, **kwargs)
else:
logger.log(*args, **kwargs)
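# Usage sketch (illustrative): log_single_rank forwards its arguments to logger.log, so the
# log level comes first. When torch.distributed is not initialized, the message is logged
# unconditionally.
def _example_log_single_rank():
    log_single_rank(logger, logging.INFO, "example message", rank=0)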
def log_on_each_pipeline_stage(logger: logging.Logger, *args: Any, **kwargs: Any):
"""Log on first rank in each pipeline stage
Args:
logger (logging.Logger): The logger to write the logs
args (Tuple[Any]): All logging.Logger.log positional arguments
kwargs (Dict[str, Any]): All logging.Logger.log keyword arguments
"""
assert torch.distributed.is_initialized()
if (
parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0
and parallel_state.get_tensor_model_parallel_rank() == 0
):
logger.log(*args, **kwargs)
def check_param_hashes_across_dp_replicas(model: List[torch.nn.Module]) -> bool:
"""Computes hashes of all parameters in model, all-gathers hashes across DP replicas,
and then checks for equality between the locally-computed hashes and the hashes
from DP replica 0.
NOTE: This function computes SHA-1 hashes on the CPU and thus needs to move all param
tensors from GPU to CPU first; as a result, this function is not intended to be called
very frequently in the main training loop.
Args:
model (List[torch.nn.Module]): List of model chunks whose parameter hashes need to
be checked.
Returns:
True if all param hashes match with corresponding hash on DP replica 0, False
otherwise.
"""
# Compute per-parameter hashes on this rank.
params = []
local_param_hashes = []
for model_chunk_id, model_chunk in enumerate(model):
for (param_name, param) in model_chunk.named_parameters():
param_hash = torch.frombuffer(
array.array(
'B', hashlib.sha1(param.data.to("cpu").float().numpy(force=True)).digest()
),
dtype=torch.uint8,
)
params.append((model_chunk_id, param_name, param))
local_param_hashes.append(param_hash)
local_param_hashes = torch.stack(local_param_hashes)
# Collect per-parameter hashes across all ranks in DP group.
all_param_hashes = [
torch.zeros_like(local_param_hashes)
for _ in range(parallel_state.get_data_parallel_world_size())
]
torch.distributed.all_gather(
all_param_hashes, local_param_hashes, group=parallel_state.get_data_parallel_group_gloo()
)
# Make sure local per-parameter hash matches DP rank 0.
param_hashes_match = torch.equal(local_param_hashes, all_param_hashes[0])
if not param_hashes_match:
for i, (model_chunk_id, param_name, param) in enumerate(params):
if not torch.equal(local_param_hashes[i], all_param_hashes[0][i]):
rank = torch.distributed.get_rank()
logger.info(
f"[Rank {rank}] Hash not matching for {param_name} in model chunk {model_chunk_id}"
)
return param_hashes_match
def make_tp_sharded_tensor_for_checkpoint(
tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs
):
""" Helper for instantiating a ShardedTensor where the `tp_axis` dimension is sharded across TP group.
Optionally, can provide offsets which prepend new dimensions to the tensor.
"""
prepend_axis_num = len(prepend_offsets)
if replica_id is None:
replica_id = (0, 0, parallel_state.get_data_parallel_rank(with_context_parallel=True))
return ShardedTensor.from_rank_offsets(
key,
tensor,
*prepend_offsets,
(
tp_axis + prepend_axis_num,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_tensor_model_parallel_world_size(),
),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
**kwargs,
)
def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_id=None, **kwargs):
""" Helper for instantiating a non-sharded ShardedTensor (replicated across TP and DP group).
Optionally, can provide offsets which prepend new dimensions to the tensor.
"""
prepend_axis_num = len(prepend_offsets)
if replica_id is None:
replica_id = (
0,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
return ShardedTensor.from_rank_offsets(
key,
tensor,
*prepend_offsets,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
**kwargs,
)
def prepare_input_tensors_for_wgrad_compute(grad_output, all_gathered_input):
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if grad_output.dim() == 3:
grad_output = grad_output.view(
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
)
all_gathered_input = all_gathered_input.view(
all_gathered_input.shape[0] * all_gathered_input.shape[1], all_gathered_input.shape[2]
)
return grad_output, all_gathered_input
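# Small illustrative check: 3D [sequence, batch, hidden] activations and gradients are
# flattened to 2D [sequence * batch, hidden] so the wgrad GEMM sees plain matrices.
def _example_prepare_wgrad_inputs():
    grad_output = torch.randn(8, 2, 16)      # [s, b, h_out]
    gathered_input = torch.randn(8, 2, 32)   # [s, b, h_in]
    g2d, a2d = prepare_input_tensors_for_wgrad_compute(grad_output, gathered_input)
    assert g2d.shape == (16, 16) and a2d.shape == (16, 32)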
def drain_embedding_wgrad_compute(config, embedding_activation_buffer, grad_output_buffer, weight):
""" Helper for performing embedding wgrad GEMM's during the pipeline drain phase, pipelines the AllGather and GEMM's.
Should only be used when pipeline model parallelism and gradient accumulation fusion are enabled.
"""
assert len(embedding_activation_buffer) == len(
grad_output_buffer
), "Length of activation and gradient buffers need to be equal!"
import fused_weight_gradient_mlp_cuda
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size,
)
input = embedding_activation_buffer.pop(0)
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gathered_input = [None, None]
if config.sequence_parallel:
all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu_0")
handle = torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=False
)
all_gathered_input[0] = all_gather_buffer
all_gather_buffer = None
else:
all_gathered_input[0] = input
input = None
def wgrad_compute(all_gathered_input, grad_output, weight):
grad_output, all_gathered_input = prepare_input_tensors_for_wgrad_compute(
grad_output, all_gathered_input
)
if config.gradient_accumulation_fusion:
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
all_gathered_input, grad_output, weight.main_grad
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
all_gathered_input, grad_output, weight.main_grad
)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
# We have all_gathered_input list acting as a double buffer here,
# since we are pipelining the AllGather and GEMM, one buffer all-gathers
# the input while the other buffer reads from it for the GEMM. We use i
# and (i+1) for indexing to enable this double buffering.
for i in range(len(embedding_activation_buffer)):
input = embedding_activation_buffer.pop(0)
if config.sequence_parallel:
name = "mpu_" + str((i + 1) % 2)
all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, name)
handle = torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
)
all_gathered_input[(i + 1) % 2] = all_gather_buffer
all_gather_buffer = None
else:
all_gathered_input[(i + 1) % 2] = input
grad_output = grad_output_buffer.pop(0)
wgrad_compute(all_gathered_input[i % 2], grad_output, weight)
input, all_gathered_input[i % 2], grad_output = None, None, None
if config.sequence_parallel:
handle.wait()
grad_output = grad_output_buffer.pop(0)
wgrad_compute(all_gathered_input[1], grad_output, weight)
input, all_gathered_input[1], grad_output = None, None, None
class _ValueWithRank:
"""This is an internal class, not for use outside this module
Attributes:
_rank (int): rank for the value
_value (float): the value it stores, e.g. elapsed time
_unit (str): unit for the value
"""
def __init__(self, value: float, rank: int, unit: str = "") -> None:
"""Initializer
Args:
_value (float): the initial value with which it is initialized
_rank (int): the rank number
_unit (str): the unit of the value, e.g. ms or flops
"""
self._rank = rank
self._value = value
self._unit = unit
def __lt__(self, other) -> bool:
""" Check if value of self is smaller than other's value
Args:
other (_ValueWithRank): The other object to compare with
Returns:
bool: True if lhs._value of operand is less than rhs._value, else False
"""
return self._value < other._value
def __gt__(self, other) -> bool:
"""Check if value of self is larger than other's value
Args:
other (_ValueWithRank): The other object to compare with
Returns:
bool: True if lhs._value of operand is greater than rhs._value, else False
"""
return self._value > other._value
def __call__(self) -> Tuple[float, int, str]:
"""Returns the value, the rank, and unit as a Tuple
Returns:
Tuple[float, int, str]: value, rank, unit
"""
return self._value, self._rank, self._unit
def __str__(self) -> str:
"""String representation of the object
Returns:
str: stringified object
"""
return f"{self._value:.2f}{self._unit}/{self._rank}"
@dataclass
class _StragglerData:
"""This is an internal dataclass, not for use outside this module
Attributes:
min_elapsed (_ValueWithRank): min iteration time across all ranks
max_elapsed (_ValueWithRank): max iteration time across all ranks
min_btime (_ValueWithRank): min cpu time across all ranks
max_btime (_ValueWithRank): max cpu time across all ranks
min_temp (_ValueWithRank): min gpu temp across all ranks
max_temp (_ValueWithRank): max gpu temp across all ranks
min_power (_ValueWithRank): min gpu power across all ranks
max_power (_ValueWithRank): max gpu power across all ranks
min_util (_ValueWithRank): min gpu util across all ranks
max_util (_ValueWithRank): max gpu util across all ranks
min_clock (_ValueWithRank): min gpu clock across all ranks
max_clock (_ValueWithRank): max gpu clock across all ranks
aflops (List[_ValueWithRank]): sorted array of (_ValueWithRank)
"""
# gemm time
min_elapsed = _ValueWithRank(sys.float_info.max, 0, "ms")
max_elapsed = _ValueWithRank(sys.float_info.min, 0, "ms")
# get_batch time
min_btime = _ValueWithRank(sys.float_info.max, 0, "us")
max_btime = _ValueWithRank(sys.float_info.min, 0, "us")
# temp
min_temp = _ValueWithRank(sys.float_info.max, 0, "C")
max_temp = _ValueWithRank(sys.float_info.min, 0, "C")
# power
min_power = _ValueWithRank(sys.float_info.max, 0, "W")
max_power = _ValueWithRank(sys.float_info.min, 0, "W")
# util
min_util = _ValueWithRank(sys.float_info.max, 0, "%")
max_util = _ValueWithRank(sys.float_info.min, 0, "%")
# clock
min_clock = _ValueWithRank(sys.float_info.max, 0, "MHz")
max_clock = _ValueWithRank(sys.float_info.min, 0, "MHz")
aflops: Union[List[_ValueWithRank], None] = None
class StragglerDetector:
"""Singleton Class implementing per rank Straggler Detector
It uses cuda events to time operations of choice using the
start and stop methods, which can be invoked directly on the
class instance or used like a python context manager.
After collection, a report() method is available to display
the collected metrics. It is only supported if CUDA is
available. See megatron/core/README_STRAGGLER.md for more info.
Note:
The instance and class attributes mentioned below are all
private to the class and have no use outside the class
Attributes:
_off (bool): current state of the toggle
start (FunctionType): start method
stop (FunctionType): stop method
world (int): world size
rank (int): rank for this instance
mmcnt (int): number of ranks to report
port (int): control port
amp (float): amplification factor for TFLOPs, default 3.0
toggle (bool): whether to start/stop detector collection
bdata (bool): when true, just collect get_batch
dev (int): cuda device
evt_q (LifoQueue): cuda event queue
start_gemm_ev (list[torch.cuda.Event]): cuda start event
stop_gemm_ev (list[torch.cuda.Event]): cuda stop event
start_data_ev (list[torch.cuda.Event]): cuda start event
stop_data_ev (list[torch.cuda.Event]): cuda stop event
start_gemm_tm (list[int]): start time (wallclock)
stop_gemm_tm (list[int]): stop time (wallclock)
start_data_tm (list[int]): start time for get_batch
stop_data_tm (list[int]): stop time for get_batch
sock (socket): the controller socket
ctrlr (Thread): the controller thread
"""
_configured = False
"""Indicates if the singleton instance is configured or not
"""
def __new__(cls: Type["StragglerDetector"]) -> "StragglerDetector":
"""Constructor
Creates an instance of the class if not created
Args:
cls (Type['StragglerDetector']): The class type
Returns:
StragglerDetector: the class instance
"""
if not hasattr(cls, "_instance"):
cls._instance = super(StragglerDetector, cls).__new__(cls)
return cls._instance
def __init__(self) -> None:
"""Initializer
The initial state of the StragglerDetector instance is disabled.
The enabled state is indicated using the self._off member variable
and the property enabled.
"""
self._off: bool = True
self.start = self.null_method
self.stop = self.null_method
self.world: int = 0
self.rank: int = 0
self.mmcnt: int = 1
self.port: int = 0
self.amp: float = 3.0
self.toggle: bool = False
self.bdata: bool = False
self.dev: Union[torch.device, int, None] = None
self.evt_q: Union[queue.LifoQueue, None] = None
self.start_gemm_ev: List[torch.cuda.Event] = []
self.stop_gemm_ev: List[torch.cuda.Event] = []
self.start_data_ev: List[torch.cuda.Event] = []
self.stop_data_ev: List[torch.cuda.Event] = []
self.start_gemm_tm: List[int] = []
self.stop_gemm_tm: List[int] = []
self.start_data_tm: List[int] = []
self.stop_data_tm: List[int] = []
self.sock: Union[socket.socket, None] = None
self.ctrlr: Union[threading.Thread, None] = None
def configure(
self,
world: int,
rank: int,
mmcnt: int = 1,
amp: float = 3.0,
port: int = 65535,
prefill: int = 1024,
enabled: bool = False,
) -> None:
"""This method is called to configure the Singleton instance
It should be called once per instantiation per process.
Note:
The constructor keeps the instance disabled,
i.e. no collection will happen even when the start/stop methods are
called. Only when enabled is True (so self._off becomes False) are the
start/stop method pointers assigned the real collection
methods; otherwise they remain the null_method.
Args:
world (int): World Size
rank (int): The rank of this trainer
mmcnt (int, optional): Number of ranks to print for showing Min/Max Etpt.
Defaults to 1.
amp (float, optional): Set to 3.0 if we only use timers in fwd pass.
Defaults to 3.0.
port (int, optional): Control port, useful only for rank-0. Defaults to 65535.
prefill (int, optional): How many cuda events to pre-populate. Defaults to 1024.
enabled (bool, optional): Whether or not collection is enabled on startup.
Defaults to False.
"""
if StragglerDetector._configured:
# don't throw
return
StragglerDetector._configured = True
self.bdata = False
self.start = self.null_method
self.stop = self.null_method
self._off = True
# No CUDA, No Support
if torch.cuda.is_available():
self._off = not enabled
self.world = world
self.rank = rank
self.mmcnt = mmcnt if mmcnt > 1 else 1
self.amp = amp
self.port = port
self.toggle = False
self.bdata = False
self.evt_q = queue.LifoQueue()
self.start_gemm_ev = []
self.stop_gemm_ev = []
self.start_data_ev = []
self.stop_data_ev = []
self.start_gemm_tm = []
self.stop_gemm_tm = []
self.start_data_tm = []
self.stop_data_tm = []
backend = torch.distributed.get_backend()
if backend == "nccl":
self.dev = torch.cuda.current_device()
else:
self.dev = torch.device("cpu")
# cache some events
for _ in range(prefill):
self.evt_q.put(torch.cuda.Event(enable_timing=True))
if self.rank == 0:
# Start the controller
self._controller()
if not self._off:
self.start = self.start_method
self.stop = self.stop_method
def reset(self) -> None:
"""This method is called to reset the metrics state of the instance
It is generally called from within elapsed() after extracting per rank metrics.
"""
if self._off:
return
# Pool them
if self.evt_q is not None:
_ = [self.evt_q.put(ev) for ev in self.start_gemm_ev]
_ = [self.evt_q.put(ev) for ev in self.stop_gemm_ev]
_ = [self.evt_q.put(ev) for ev in self.start_data_ev]
_ = [self.evt_q.put(ev) for ev in self.stop_data_ev]
self.start_gemm_ev = []
self.stop_gemm_ev = []
self.start_data_ev = []
self.stop_data_ev = []
# Use regular timers
self.start_gemm_tm = []
self.stop_gemm_tm = []
self.start_data_tm = []
self.stop_data_tm = []
self.bdata = False
def start_method(self) -> None:
"""This method adds the start timers.
Both cuda event and perf_counter are added. If bdata is set to
true from __call__, this method skips inserting cuda
timer. This way it can be used to measure time spent on
CPU - generally useful for timing get_batch()
"""
# Not reentrant
if self.evt_q is not None and self.evt_q.qsize() > 1:
sev = self.evt_q.get() # no try-catch
eev = self.evt_q.get() # no try-catch
else:
sev = torch.cuda.Event(enable_timing=True)
eev = torch.cuda.Event(enable_timing=True)
# First check if this start is for data
if self.bdata:
self.start_data_ev.append(sev)
self.stop_data_ev.append(eev)
self.start_data_tm.append(0)
self.stop_data_tm.append(0)
idx = len(self.stop_data_tm) - 1
self.start_data_tm[idx] = time.perf_counter_ns()
self.start_data_ev[idx].record()
self.bdata = False
return
self.start_gemm_ev.append(sev)
self.stop_gemm_ev.append(eev)
self.start_gemm_tm.append(0)
self.stop_gemm_tm.append(0)
idx = len(self.stop_gemm_tm) - 1
self.start_gemm_tm[idx] = time.perf_counter_ns()
self.start_gemm_ev[idx].record()
def stop_method(self) -> None:
"""This method adds the stop timers.
Both cuda event and perf_counter are added. If bdata is set to
true from __call__, this method skips inserting cuda
timer. Also see start_method()
"""
# Not reentrant
# First check if this stop is for data
idx = len(self.stop_data_tm) - 1
if idx >= 0 and self.stop_data_tm[idx] == 0:
self.stop_data_tm[idx] = time.perf_counter_ns()
self.stop_data_ev[idx].record()
return
idx = len(self.stop_gemm_tm) - 1
if idx >= 0 and self.stop_gemm_tm[idx] == 0:
self.stop_gemm_tm[idx] = time.perf_counter_ns()
self.stop_gemm_ev[idx].record()
def elapsed(self) -> Tuple[float, float, int, int, int, int]:
"""This method is called from report(), or can be called directly
It is called to collect all the elapsed time since last reset().
It finally calls reset()
Returns:
Tuple[float, float, int, int, int, int]: see below for returns
delta : time spent in kernel
batch_delta : time spent in get_batch
temp : observed gpu temp
power : observed gpu power
util : observed gpu utilization
clock : observed gpu clock
"""
if self._off:
# match with return below
return 0, 0, 0, 0, 0, 0
ls_ev = len(self.start_gemm_ev)
le_ev = len(self.stop_gemm_ev)
ls_bs = len(self.start_data_ev)
ls_be = len(self.stop_data_ev)
delta = 0.0
batch_delta = 0.0
        temp = 0
        power = 0
        util = 0
        clock = 0
if ls_ev != le_ev:
logger.warning(f"Event Start/Stop out of sync {ls_ev}/{le_ev}")
elif ls_bs != ls_be:
logger.warning(f"get_batch Start/Stop out of sync {ls_bs}/{ls_be}")
else:
temp = torch.cuda.temperature()
power = torch.cuda.power_draw()
util = torch.cuda.utilization()
clock = torch.cuda.clock_rate()
torch.cuda.synchronize()
# Process Events
for i in range(ls_ev):
e_ev = self.start_gemm_ev[i].elapsed_time(self.stop_gemm_ev[i])
e_tm = (self.stop_gemm_tm[i] - self.start_gemm_tm[i]) / 1e6 # ns to ms
# Pick the larger of Event and perf_counter time?
delta += max(e_ev, e_tm)
# Process get_batch
for i in range(ls_bs):
b_ev = self.start_data_ev[i].elapsed_time(self.stop_data_ev[i])
b_tm = (self.stop_data_tm[i] - self.start_data_tm[i]) / 1e6 # ns to ms
# data fetching has prefetch, hence take the max, instead of avg
batch_delta = max(batch_delta, max(b_ev, b_tm))
self.reset() # Prepare for next round
# time in ms, batch_delta in ms, check return above
return delta, batch_delta, temp, power, util, clock
def report(self, total_flops: float = 0.0, log_interval: int = 0) -> bool:
"""Function to log the min/max metircs and the associated rank over a time period
It finds the slowest and fastest rank among all ranks. It should be
called by all ranks, but only rank-0 prints the analysis
At the end it checks, if the straggler detector should
remain active or if it should be deactivated.
Args:
total_flops (float, optional): The theoretical flops over the period. Defaults to 0.0.
log_interval (int, optional): The training interval over which reporting is called(ms)
Defaults to 0.
Returns:
bool: True if reported, else False
"""
ret = False
if not self._off and total_flops > 0.0 and log_interval > 0:
elapsed, btime, temp, power, util, clock = self.elapsed() # get raw time
            # btime: get_batch time is the max within the iteration
            ptime = elapsed / (log_interval * 1.0)  # avg per iteration elapsed time, ms
            api_flops = total_flops / (log_interval * 1.0)  # avg per iteration flops
            apir_flops = api_flops / (
                ptime * 10 ** 9 * self.world
            )  # avg per-iteration throughput for this rank, TFLOP/s (note 10**9)
et_flops = apir_flops / self.amp # Estimated TFLOPs, not tracing backward
o_dt = self._min_max(
ptime, btime, float(temp), float(power), float(util), float(clock), et_flops,
)
if self.rank == 0 and o_dt is not None and o_dt.aflops is not None:
now = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
min_flops, min_frank, _ = o_dt.aflops[0]()
max_flops, max_frank, _ = o_dt.aflops[-1]()
logger.info(
f"{now} | "
f"MnRtt/Rnk: {o_dt.min_elapsed} | "
f"MxRtt/Rnk: {o_dt.max_elapsed} | "
f"MnPwr/Rnk: {o_dt.min_power} | "
f"MxPwr/Rnk: {o_dt.max_power} | "
f"MnTmp/Rnk: {o_dt.min_temp} | "
f"MxTmp/Rnk: {o_dt.max_temp} | "
f"MnUtl/Rnk: {o_dt.min_util} | "
f"MxUtl/Rnk: {o_dt.max_util} | "
f"MnClk/Rnk: {o_dt.min_clock} | "
f"MxClk/Rnk: {o_dt.max_clock} | "
f"MnDRtt/Rnk: {o_dt.min_btime} | "
f"MxDRtt/Rnk: {o_dt.max_btime} | "
f"MnEtpt/Rnk: {min_flops:.2f}TF/{min_frank} | "
f"MxEtpt/Rnk: {max_flops:.2f}TF/{max_frank}"
)
if self.mmcnt > 1 and self.mmcnt < self.world:
line = f"^^^^ Bottom {self.mmcnt} Ranks with lowest Etpt(TF):"
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i]},"
logger.info(line)
line = f"^^^^ Top {self.mmcnt} Ranks with highest Etpt(TF):"
shift = self.world - self.mmcnt
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i+shift]},"
logger.info(line)
ret = True
# Check/Communicate if tracking is turned off or on
self._check_toggle()
return ret
def _check_toggle(self) -> None:
"""Helper method to check if a request to toggle the collection state was made
It checks iof collection state toggle req was made via the server listening on
rank-0 since last call to report(). Called by report(). Calling this method
indirectly from report() is the only way to activate the change that is made
via rank-0
"""
# If no change just commnunicate the current
off = self._off
if self.rank == 0 and self.toggle:
off = not self._off
self.toggle = False
st = torch.tensor(off, dtype=torch.bool, device=self.dev)
torch.distributed.broadcast(st, 0) # Blocking
# save old switch
off = self._off
self._off = bool(st.item())
if off != self._off:
if not self._off:
self.start = self.start_method
self.stop = self.stop_method
state = "ON"
else:
self.start = self.null_method
self.stop = self.null_method
state = "OFF"
if self.rank == 0:
logger.info(f"Toggling StragglerDetector State {state}")
def _handler(self) -> None:
"""Thread function for the controller.
It is a tcp-server that listens on a port. Uses HTTP protocol.
If connected to it using curl, it indicates a toggle of the
collection state. The actual toggling happens at the end of
calling report() when _check_toggle() is called.
"""
resp = f"HTTP/1.0 200 OK\r\nConnection: Close\r\nContent-length: "
if self.rank == 0:
state = "OFF" if self._off else "ON"
logger.info(
f"Controller ready to recv " f"commands on port {self.port}. Current state {state}"
)
while True and self.sock is not None:
try:
conn, _ = self.sock.accept()
_ = conn.recv(1024)
self.toggle = True
state = "ON" if self._off else "OFF"
msg = f"Will turn StragglerDetector {state} at next logging interval"
msg_len = len(msg)
final_resp = f"{resp}{msg_len}\r\n\r\n{msg}"
conn.send(final_resp.encode())
conn.close()
logger.info(msg)
except Exception as err:
logger.error(f"Error in stragler handler.. {str(err)}")
return
def _controller(self):
"""Installs a controller listener that is used to toggle collection state.
Called from configure(). Ignored for all ranks other than rank-0
"""
try:
if self.rank == 0:
neth = "0.0.0.0"
netp = self.port
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((neth, netp))
self.sock.listen(128)
self.ctrlr = threading.Thread(
target=self._handler, args=(), name="straggler", daemon=True
)
self.ctrlr.start()
except Exception as err:
logger.warning(f"StragglerDetector cannot be controlled.. {str(err)}")
def _min_max(
self,
ptime: float,
btime: float,
temp: float,
power: float,
util: float,
clock: float,
flops: float,
) -> Union[_StragglerData, None]:
"""Helper function to find the min/max values
Args:
ptime (float): avg per iteration gpu time
btime (float): avg per iteration cpu time
temp (float): gpu temp at the time of reporting
power (float): gpu power at the time of reporting
util (float): gpu util at the time of reporting
clock (float): gpu clock at the time of reporting
flops (float): estimated flops for the rank
        Returns:
            Union[_StragglerData, None]: It contains the min/max of a few metrics and the
                corresponding rank. It also has a sorted list of all (flops, rank)
                pairs sorted by flops (aflops), or returns None if collection is disabled.
        """
if self._off:
return None
# initialize output data object
o_dt = _StragglerData()
prof_data: Dict[str, Union[int, float]] = {}
data_list: List[Dict[str, Union[int, float]]] = []
prof_data["rank"] = self.rank
prof_data["time"] = ptime
prof_data["btime"] = btime
prof_data["temp"] = temp
prof_data["power"] = power
prof_data["util"] = util
prof_data["clock"] = clock
prof_data["flops"] = flops
if self.rank == 0:
data_list = [prof_data] * self.world
# this is blocking by default
torch.distributed.gather_object(prof_data, object_gather_list=data_list, dst=0)
if self.rank == 0:
min_ctime = min(data_list, key=lambda k: k["time"]) # elapsed
max_ctime = max(data_list, key=lambda k: k["time"]) # elapsed
min_cbatch = min(data_list, key=lambda k: k["btime"]) # batch time
max_cbatch = max(data_list, key=lambda k: k["btime"]) # batch time
min_ctemp = min(data_list, key=lambda k: k["temp"]) # temp
max_ctemp = max(data_list, key=lambda k: k["temp"]) # temp
min_cpower = min(data_list, key=lambda k: k["power"]) # power
max_cpower = max(data_list, key=lambda k: k["power"]) # power
min_cutil = min(data_list, key=lambda k: k["util"]) # gpu util
max_cutil = max(data_list, key=lambda k: k["util"]) # gpu util
min_cclock = min(data_list, key=lambda k: k["clock"]) # gpu clock
max_cclock = max(data_list, key=lambda k: k["clock"]) # gpu clock
min_val = min_ctime["time"]
min_rank = min_ctime["rank"]
max_val = max_ctime["time"]
max_rank = max_ctime["rank"]
o_dt.min_elapsed = _ValueWithRank(min_val, int(min_rank), "ms")
o_dt.max_elapsed = _ValueWithRank(max_val, int(max_rank), "ms")
min_val = min_cbatch["btime"]
min_rank = min_cbatch["rank"]
max_val = max_cbatch["btime"]
max_rank = max_cbatch["rank"]
o_dt.min_btime = _ValueWithRank(min_val, int(min_rank), "ms")
o_dt.max_btime = _ValueWithRank(max_val, int(max_rank), "ms")
min_val = min_ctemp["temp"]
min_rank = min_ctemp["rank"]
max_val = max_ctemp["temp"]
max_rank = max_ctemp["rank"]
o_dt.min_temp = _ValueWithRank(min_val, int(min_rank), "C")
o_dt.max_temp = _ValueWithRank(max_val, int(max_rank), "C")
min_val = min_cpower["power"]
min_rank = min_cpower["rank"]
max_val = max_cpower["power"]
max_rank = max_cpower["rank"]
o_dt.min_power = _ValueWithRank(min_val, int(min_rank), "W")
o_dt.max_power = _ValueWithRank(max_val, int(max_rank), "W")
min_val = min_cutil["util"]
min_rank = min_cutil["rank"]
max_val = max_cutil["util"]
max_rank = max_cutil["rank"]
o_dt.min_util = _ValueWithRank(min_val, int(min_rank), "%")
o_dt.max_util = _ValueWithRank(max_val, int(max_rank), "%")
min_val = min_cclock["clock"]
min_rank = min_cclock["rank"]
max_val = max_cclock["clock"]
max_rank = max_cclock["rank"]
o_dt.min_clock = _ValueWithRank(min_val, int(min_rank), "MHz")
o_dt.max_clock = _ValueWithRank(max_val, int(max_rank), "MHz")
o_dt.aflops = [
_ValueWithRank(d.get("flops", 0.0), int(d.get("rank", -1)))
for _, d in enumerate(data_list)
]
o_dt.aflops.sort(key=lambda val_with_rank: val_with_rank()[0])
# wait for everyone here
torch.distributed.barrier()
return o_dt
@property
def enabled(self) -> bool:
"""Can be called to check the enabled state of the instance
Note:
After the request to toggle the state, the
actual state change happens at end of call
to report()
"""
return not self._off
@property
def configured(self) -> bool:
"""Can be called to check if the the instance is already configured
Returns:
bool: returns True if configure was called and was a success, else False
"""
return StragglerDetector._configured
@property
def my_rank(self):
"""Can be called to get configured rank of this instance
Returns:
int: Configured rank for this instance
"""
return self.rank
@property
def world_size(self) -> int:
"""Can be called to get configured world of this instance
Returns:
int: World size configured for this instance
"""
return self.world
def null_method(self) -> None:
"""Default method to initialize start/stop method ptrs"""
pass
def __enter__(self) -> "StragglerDetector":
"""Define context/instance entry
Returns:
StragglerDetector: the instance
"""
self.start()
return self
def __call__(self, bdata: bool = False) -> "StragglerDetector":
"""Callable for the instance. Set context state,
Useful when the context is used for cpu timers only when bdata=True
Args:
bdata (bool, optional): when true, only enables cpu timers. Defaults to False.
Returns:
StragglerDetector: the instance
"""
self.bdata = bdata
return self
def __exit__(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> bool:
"""Define context/instance exit, calls the stop method
Args:
ex_type (Optional[Type[BaseException]]): Exception type
            ex_val (Optional[BaseException]): Exception value
            ex_tb (Optional[TracebackType]): Exception traceback
Returns:
bool: True if the exception was handled
"""
# Should not suppress errors even if turned off
if ex_type is not None:
err = traceback.format_exception(ex_type, ex_val, ex_tb)
logger.warning(f"{str(ex_val)}\n{err}")
self.stop()
return False
# Singleton, global visibility
__straggler__ = StragglerDetector()
"""StragglerDetector: private module variable, not be directly accessed
"""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
def add_modelopt_args(parser):
"""Add additional arguments for using TensorRT Model Optimizer (modelopt) features."""
group = parser.add_argument_group(title="modelopt-generic")
group.add_argument(
"--export-legacy-megatron",
action="store_true",
help="Export a legacy megatron-lm checkpoint.",
)
group.add_argument(
"--export-te-mcore-model",
action="store_true",
help="Export a megatron-core transformer-engine checkpoint.",
)
group.add_argument(
"--export-quant-cfg",
type=str,
default=None,
choices=["int8", "int8_sq", "fp8", "int4_awq", "w4a8_awq", "int4", "None"],
help="Specify a quantization config from the supported choices.",
)
return parser
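# Illustrative sketch (not part of the upstream API): composing the modelopt
# argument group with a bare argparse parser. The parser name and example argv
# are assumptions for demonstration only.
def _example_parse_modelopt_args(argv=None):
    import argparse
    parser = argparse.ArgumentParser(description="modelopt args demo")
    parser = add_modelopt_args(parser)
    # e.g. argv = ["--export-te-mcore-model", "--export-quant-cfg", "fp8"]
    return parser.parse_args(argv)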
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
from pathlib import Path
from typing import Optional, Dict
from megatron.core import dist_checkpointing
from megatron.training import get_args
from megatron.training.checkpointing import _load_base_checkpoint, load_checkpoint
from megatron.training.utils import print_rank_0, unwrap_model
try:
from modelopt.torch.opt.plugins import (
get_sharded_modelopt_state,
restore_modelopt_state_metadata,
)
except ImportError as e:
raise ImportError("Required `\"nvidia-modelopt[torch]\"` is not installed!") from e
def load_modelopt_state(load_dir: Optional[str] = None) -> Dict:
"""Loading modelopt_state without a model.
If --use-dist-ckpt, we try to load from the sharded modelopt_state. This will not load the model
state_dict. Otherwise, if the checkpoint is not sharded, we load the base checkpoint (that
contains the model state as well) and extract the modelopt_state.
Args:
load_dir: optionally provide a different loading path
"""
args = get_args()
if load_dir is None:
load_dir = args.load
if args.use_dist_ckpt:
# Read the tracker file and set the iteration.
tracker_filename = os.path.join(load_dir, 'latest_checkpointed_iteration.txt')
        # If there is no tracker file, assume that it is a .nemo checkpoint.
if not os.path.isfile(tracker_filename):
sharded_load_dir = Path(load_dir) / "model_weights"
else:
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
sharded_load_dir = Path(load_dir) / 'iter_{:07d}'.format(iteration)
except ValueError:
sharded_load_dir = Path(load_dir) / metastring
modelopt_state_dir = sharded_load_dir / "modelopt_state"
if modelopt_state_dir.exists():
print_rank_0("Loading sharded modelopt_state ({})".format(modelopt_state_dir))
modelopt_state = restore_modelopt_state_metadata(
dist_checkpointing.load(
get_sharded_modelopt_state(args.num_layers), modelopt_state_dir,
)
)
return modelopt_state
else:
print_rank_0(
"sharded modelopt_state ({}) does not exist!".format(modelopt_state_dir)
)
return {}
else:
print_rank_0("Loading modelopt_state from base checkpoint ({})".format(load_dir))
try:
            state_dict, _, _ = _load_base_checkpoint(load_dir, rank0=False)
except Exception:
print_rank_0("Failed to load base checkpoint via megatron _load_base_checkpoint!")
return {}
if state_dict is None:
return {}
return state_dict.get("modelopt_state", {})
def load_modelopt_checkpoint(
model,
optimizer=None,
opt_param_scheduler=None,
strict: bool = True,
additional_sharded_prefix: str = "model.",
load_arg: str = "load",
) -> None:
"""Load a sharded (untar .nemo or megatron --use-dist-ckpt) or unsharded checkpoint.
Essentially, the function is detecting whether the checkpoint is a .nemo sharded checkpoint.
If so, we load the sharded state_dict with additional_sharded_prefix `model.`.
    This additional prefix is the artifact of the lightning module wrapper. Once the sharded
state_dict is loaded, we use a state_dict pre_hook to pop this additional prefix (`model.`)
from all state_dict keys.
If this is not a .nemo sharded checkpoint, then this function will simply call
load_checkpoint. See megatron.checkpointing.load_checkpoint for explanation.
Args:
additional_sharded_prefix: append additional prefix to align the sharded checkpoint keys.
            When loading a .nemo sharded checkpoint, this is usually `model.`. Otherwise, this is
typically an empty string.
"""
def _remove_prefix_state_dict_pre_hook(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs,
):
"""Pytorch state_dict pre_hook to remove prefix of the state_dict keys."""
if additional_sharded_prefix is None:
return
key_rewrite_list = []
for key, _ in state_dict.items():
if key.startswith(additional_sharded_prefix):
key_rewrite_list.append(key)
for old_key in key_rewrite_list:
new_key = old_key[len(additional_sharded_prefix) :]
state_dict[new_key] = state_dict.pop(old_key)
args = get_args()
load_dir = getattr(args, load_arg)
sharded_load_dir = Path(load_dir) / "model_weights"
if sharded_load_dir.exists() and optimizer is None and opt_param_scheduler is None:
unwrapped_model = unwrap_model(model)
        # Setting this attribute will alter the sharded_offsets of transformer_block.
unwrapped_model[0].decoder.config.non_homogeneous_layers = False
sharded_state_dict = unwrapped_model[0].sharded_state_dict(prefix=additional_sharded_prefix)
if additional_sharded_prefix:
unwrapped_model[0]._register_load_state_dict_pre_hook(
_remove_prefix_state_dict_pre_hook
)
unwrapped_model[0].load_state_dict(
dist_checkpointing.load(sharded_state_dict, sharded_load_dir)
)
        # Set the attribute back to True so that, by default, we store the heterogeneous arch.
unwrapped_model[0].decoder.config.non_homogeneous_layers = True
else:
_ = load_checkpoint(model, optimizer, opt_param_scheduler, strict=strict, load_arg=load_arg)
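# Illustrative sketch (not part of the checkpointing API): what the `model.`
# prefix-stripping pre-hook above does, shown on a plain dict. The helper name
# and sample keys are hypothetical.
def _example_strip_prefix(state_dict, prefix="model."):
    # Rewrite keys such as "model.decoder.weight" -> "decoder.weight" in place.
    for key in [k for k in state_dict if k.startswith(prefix)]:
        state_dict[key[len(prefix):]] = state_dict.pop(key)
    return state_dict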
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""ModelOpt GPT model provider."""
import modelopt.torch.opt as mto
from megatron.core.inference.gpt.model_specs import get_gpt_layer_modelopt_spec
from megatron.core.inference.gpt.state_dict_hooks import (
mcore_gpt_load_legacy_state_dict_pre_hook,
mcore_gpt_load_te_state_dict_pre_hook,
)
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.parallel_state import get_tensor_model_parallel_rank
from megatron.core.transformer.spec_utils import import_module
from megatron.inference.checkpointing import load_modelopt_state
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
def model_provider(pre_process=True, post_process=True, parallel_output=True) -> MCoreGPTModel:
"""Builds the model.
    ModelOpt integration only supports the MCore GPT model; setting use_legacy_models raises an error.
    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.
        parallel_output (bool): Whether to allgather the output logits. This must be
True if `model_provider` is called in text_generation_server.
Returns:
MCoreGPTModel: The returned model
"""
args = get_args()
print_rank_0("building GPT model ...")
    # ModelOpt by default assumes non-homogeneous layers. This affects the storage format of the sharded checkpoint.
config = core_transformer_config_from_args(args)
config.non_homogeneous_layers = True
if args.use_legacy_models:
raise ValueError(
"ModelOpt integration only support MCore models. Use --use-mcore-modules instead."
)
if args.spec is not None:
transformer_layer_spec = import_module(args.spec)
else:
transformer_layer_spec = get_gpt_layer_modelopt_spec(
remap_te_layernorm=args.export_te_mcore_model, qk_layernorm=False,
)
model_type = MCoreGPTModel
model_kwargs = {
"config": config,
"transformer_layer_spec": transformer_layer_spec,
"vocab_size": args.padded_vocab_size,
"max_sequence_length": args.max_position_embeddings,
"pre_process": pre_process,
"post_process": post_process,
"fp16_lm_cross_entropy": args.fp16_lm_cross_entropy,
"parallel_output": parallel_output,
"share_embeddings_and_output_weights": not args.untie_embeddings_and_output_weights,
"position_embedding_type": args.position_embedding_type,
"rotary_percent": args.rotary_percent,
}
model = model_type(**model_kwargs)
# Load modelopt_state
modelopt_state = load_modelopt_state() if args.load else {}
if modelopt_state:
model = mto.restore_from_modelopt_state(model, modelopt_state)
# Register some load_state_dict prehooks to handle some known state_dict key mismatch.
# (legacy <-> modelopt) and (default te <-> modelopt)
if args.export_legacy_megatron:
model._register_load_state_dict_pre_hook(mcore_gpt_load_legacy_state_dict_pre_hook)
if args.export_te_mcore_model:
model._register_load_state_dict_pre_hook(mcore_gpt_load_te_state_dict_pre_hook)
# Print models on all pp ranks.
if get_tensor_model_parallel_rank() == 0:
print(str(model))
return model
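# Illustrative usage sketch (hypothetical helper): model_provider relies on
# megatron globals, so it is only valid after initialize_megatron() has parsed
# the arguments and set up the model-parallel state.
def _example_build_model():
    # Defaults mirror training-style construction; see the docstring above for
    # the parallel_output requirement in the text-generation server.
    return model_provider(pre_process=True, post_process=True)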
<!-- coding=utf-8-->
<!-- Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.-->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Megatron</title>
<style>
.wrapper {
max-width: 75%;
margin: auto;
}
h1 {
margin: 3rem 0 1rem 0;
padding: 0;
font-size: 1.5rem;
}
textarea {
width: 100%;
min-height: 300px;
resize: none;
border-radius: 8px;
border: 1px solid #ddd;
padding: 0.5rem;
box-shadow: inset 0 0 0.25rem #ddd;
&:focus {
outline: none;
border: 1px solid darken(#ddd, 5%);
box-shadow: inset 0 0 0.5rem darken(#ddd, 5%);
}
}
#the-count {
float: right;
padding: 0.1rem 0 0 0;
font-size: 0.875rem;
}
/* Chat containers */
.container {
font-family: 'Arial', sans-serif;
font-size: 16px;
border: 2px solid #dedede;
background-color: #f1f1f1;
border-radius: 5px;
padding: 15px;
margin: 10px 0;
}
/* Clear floats */
.container::after {
content: "";
clear: both;
display: table;
}
/* Style images */
.container img {
float: left;
max-width: 60px;
width: 100%;
margin-right: 20px;
border-radius: 50%;
}
</style>
</head>
<body>
<div class="wrapper">
<h1>Prompt Megatron</h1>
<textarea name="prompt" id="prompt" maxlength="1024" placeholder="Add prompt"autofocus></textarea>
<label for="tokens_to_generate">Number tokens to generate (1-1024):</label>
<input type="number" id="tokens_to_generate" name="tokens_to_generate" min="10" max="256", value=32>
<button onclick="submit_query()">Submit</button>
<div id="the-count">
<span id="current">0</span>
<span id="maximum">/ 1000</span>
</div>
<textarea name="response" id="response" maxlength="2048" placeholder="Megatron response..."></textarea>
</div>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script type="text/javascript">
function submit_query() {
$("#response").val("Waiting for Megatron response...");
$.ajax({
url:"api",
type:"PUT",
data:JSON.stringify({prompts: [$("#prompt").val()], tokens_to_generate: parseInt($("#tokens_to_generate").val(),10)}),
contentType:"application/json; charset=utf-8",
dataType:"json",
success: function(data){
data.max_len=35;
$("#response").val(data.text);
}
});
}
$('textarea').keyup(function() {
var characterCount = $(this).val().length,
current = $('#current'),
maximum = $('#maximum'),
theCount = $('#the-count');
current.text(characterCount);
if (characterCount >= 800) {
maximum.css('color', '#8f0001');
current.css('color', '#8f0001');
theCount.css('font-weight','bold');
} else {
maximum.css('color','#666');
theCount.css('font-weight','normal');
}
});
</script>
</body>
</html>
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from .api import (
generate,
generate_and_post_process,
beam_search_and_post_process)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Inference API."""
import torch
from megatron.core import mpu
from .communication import broadcast_float_list
from .generation import (
generate_tokens_probs_and_return_on_first_stage,
score_and_return_on_first_stage,
beam_search_and_return_on_first_stage)
from .tokenization import (
tokenize_prompts,
detokenize_generations)
from .forward_step import ForwardStep
def generate_and_post_process(model,
forward_step=ForwardStep,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
top_p_decay=0.0,
top_p_bound=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
prevent_newline_after_colon=False,
random_seed=-1,
return_logits=False):
"""Run inference and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""
# Main inference.
tokens, lengths, output_log_probs, logits = generate(
model,
forward_step=forward_step,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=return_output_log_probs,
top_k_sampling=top_k_sampling,
top_p_sampling=top_p_sampling,
top_p_decay=top_p_decay,
top_p_bound=top_p_bound,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=use_eod_token_for_early_termination,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
prevent_newline_after_colon=prevent_newline_after_colon,
random_seed=random_seed)
# Only post-process on first stage.
if mpu.is_pipeline_first_stage():
tokens, prompts_plus_generations, prompts_plus_generations_segments = \
detokenize_generations(tokens, lengths, True)
if return_output_log_probs:
output_log_probs = output_log_probs.cpu().numpy().tolist()
for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
output_log_probs[i] = prob[:len(seg)-1]
if return_logits:
assert(tokens_to_generate == 0)
assert(mpu.get_pipeline_model_parallel_world_size() == 1)
return prompts_plus_generations, prompts_plus_generations_segments, \
output_log_probs, tokens, logits
else:
return prompts_plus_generations, prompts_plus_generations_segments, \
output_log_probs, tokens
return None
def generate(model,
forward_step=None,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
top_p_decay=0.0,
top_p_bound=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
prevent_newline_after_colon=False,
random_seed=-1):
"""Given prompts and input parameters, run inference and return:
tokens: prompts plus the generated tokens.
lengths: length of the prompt + generations. Note that we can
discard tokens in the tokens tensor that are after the
corresponding length.
output_log_probs: log probs of the tokens.
"""
    # Make sure input params are available to all ranks.
values = [tokens_to_generate,
return_output_log_probs,
top_k_sampling, top_p_sampling, top_p_decay, top_p_bound,
temperature, add_BOS, use_eod_token_for_early_termination,
stop_on_double_eol,
stop_on_eol,
prevent_newline_after_colon,
random_seed]
values_float_tensor = broadcast_float_list(len(values), float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
top_k_sampling = int(values_float_tensor[2].item())
top_p_sampling = values_float_tensor[3].item()
top_p_decay = values_float_tensor[4].item()
top_p_bound = values_float_tensor[5].item()
temperature = values_float_tensor[6].item()
add_BOS = bool(values_float_tensor[7].item())
use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
stop_on_double_eol = bool(values_float_tensor[9].item())
stop_on_eol = bool(values_float_tensor[10].item())
prevent_newline_after_colon = bool(values_float_tensor[11].item())
random_seed = int(values_float_tensor[12].item())
if random_seed != -1:
torch.random.manual_seed(random_seed)
# Tokenize prompts and get the batch.
    # Note that these tensors are broadcast to all ranks.
if torch.distributed.get_rank() == 0:
assert prompts is not None
context_tokens_tensor, context_length_tensor = tokenize_prompts(
prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
if tokens_to_generate == 0:
return score_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor)
# Main inference function.
# Note that the outputs are available on the first stage.
return generate_tokens_probs_and_return_on_first_stage(
model, forward_step, context_tokens_tensor, context_length_tensor,
return_output_log_probs=return_output_log_probs,
top_k=top_k_sampling,
top_p=top_p_sampling,
top_p_decay=top_p_decay,
top_p_bound=top_p_bound,
temperature=temperature,
use_eod_token_for_early_termination=use_eod_token_for_early_termination,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
prevent_newline_after_colon=prevent_newline_after_colon)
def beam_search_and_post_process(model,
forward_step=ForwardStep,
prompts=None,
tokens_to_generate=0,
beam_size=0,
add_BOS=False,
stop_token=50256,
num_return_gen=1,
length_penalty=1,
prevent_newline_after_colon=False):
"""Run beam search and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""
# Main inference.
tokens, scores = beam_search(model,
forward_step=forward_step,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
beam_size=beam_size,
add_BOS=add_BOS,
stop_token=stop_token,
num_return_gen=num_return_gen,
length_penalty=length_penalty,
prevent_newline_after_colon=prevent_newline_after_colon)
# Only post-process on first stage.
if mpu.is_pipeline_first_stage():
lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())
tokens, prompts_plus_generations, prompts_plus_generations_segments = detokenize_generations(tokens, lengths, True)
scores = scores.cpu().numpy().tolist()
return prompts_plus_generations, prompts_plus_generations_segments, scores
return None
def beam_search(model, forward_step, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1, length_penalty=1, prevent_newline_after_colon=False):
    # Make sure input params are available to all ranks.
values = [tokens_to_generate,
beam_size,
add_BOS,
stop_token,
num_return_gen,
length_penalty,
prevent_newline_after_colon]
values_float_tensor = broadcast_float_list(len(values), float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
beam_size = int(values_float_tensor[1].item())
add_BOS = bool(values_float_tensor[2].item())
stop_token = int(values_float_tensor[3].item())
num_return_gen = int(values_float_tensor[4].item())
length_penalty = values_float_tensor[5].item()
    prevent_newline_after_colon = bool(values_float_tensor[6].item())
context_tokens_tensor, context_length_tensor = tokenize_prompts(
prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
return beam_search_and_return_on_first_stage(model, forward_step, context_tokens_tensor, context_length_tensor,
beam_size, stop_token=stop_token, num_return_gen=num_return_gen, length_penalty=length_penalty,
prevent_newline_after_colon=prevent_newline_after_colon)
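# Illustrative usage sketch (not part of the API): driving generate_and_post_process.
# The prompt and sampling values are arbitrary examples; every rank must enter the
# call, but only the first pipeline stage gets a non-None result.
def _example_generate(model):
    result = generate_and_post_process(
        model,
        prompts=["Megatron is"],
        tokens_to_generate=32,
        return_output_log_probs=True,
        top_p_sampling=0.9,
        temperature=1.0,
    )
    if result is None:
        return None
    prompts_plus_generations, _, output_log_probs, _ = result
    return prompts_plus_generations, output_log_probs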
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
## from huggingface beam search
class BeamHypotheses(object):
def __init__(self, num_beams, length_penalty=1.0, early_stopping=False):
"""
Initialize n-best list of hypotheses.
"""
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(self, hyp, sum_logprobs, length):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / length ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
del self.beams[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len):
"""
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
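# Illustrative toy example (not part of the upstream code): track the 2 best
# hypotheses by length-normalized log probability. The token sequences here are
# plain lists standing in for token tensors.
def _example_beam_hypotheses():
    hyp = BeamHypotheses(num_beams=2, length_penalty=1.0)
    hyp.add([5, 7, 9], sum_logprobs=-1.2, length=3)
    hyp.add([5, 7, 2], sum_logprobs=-0.6, length=3)
    hyp.add([5, 7, 4], sum_logprobs=-2.5, length=3)  # worse than both, gets dropped
    assert len(hyp) == 2
    # With the best achievable score below the current worst kept score,
    # the search for this sentence can stop.
    return hyp.is_done(best_sum_logprobs=-3.0, cur_len=3)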
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Communications utilities."""
import torch
from megatron.core import mpu
# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
"""Receive from previous pipeline stage and update the
input buffer inplace."""
if not mpu.is_pipeline_first_stage():
assert recv_buffer is not None
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_buffer,
mpu.get_pipeline_model_parallel_prev_rank())
reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
"""Send output to the next pipeline stage."""
if not mpu.is_pipeline_last_stage():
assert tensor is not None
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor,
mpu.get_pipeline_model_parallel_next_rank())
reqs = torch.distributed.batch_isend_irecv([send_next_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
def _is_cuda(tensor):
"""Check if a tensor is not none and is cuda."""
assert tensor is not None
assert tensor.is_cuda
def _is_cuda_contiguous(tensor):
"""Check if a tensor is not none, is cuda, and is contiguous."""
_is_cuda(tensor)
assert tensor.is_contiguous()
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
"""Broadcast a tensor from last pipeline stage to all ranks."""
is_last_stage = mpu.is_pipeline_last_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
if mpu.is_pipeline_first_stage() and is_last_stage:
return tensor
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Get the group and corresponding source rank.
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(tensor, src, group)
return tensor
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Broadcast tensor values from last stage into the first stage."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return tensor
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor, src, group)
else:
tensor = None
return tensor
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Copy tensor values from last stage into the first stage.
Note that the input tensor is updated in place."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
_is_cuda(tensor)
is_contiguous = tensor.is_contiguous()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
if is_contiguous:
tensor_ = tensor
else:
if is_last_stage:
tensor_ = tensor.contiguous()
else:
tensor_ = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor_, src, group)
# Update the first stage tensor
if is_first_stage and not is_contiguous:
tensor[...] = tensor_
def broadcast_tensor(size, dtype, tensor=None, rank=0):
""" Given size and type of a tensor on all ranks and the tensor value
only on a specific rank, broadcast from that rank to all other ranks.
"""
if torch.distributed.get_rank() == rank:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
torch.distributed.broadcast(tensor, rank)
return tensor
def broadcast_list(size, dtype, list_values=None, rank=0):
"""Broadcast a list of values with a given type."""
tensor = None
if torch.distributed.get_rank() == rank:
tensor = torch.tensor(list_values, dtype=dtype,
device=torch.cuda.current_device())
return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)
def broadcast_int_list(size, int_list=None, rank=0):
"""Broadcast a list of interger values."""
return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
def broadcast_float_list(size, float_list=None, rank=0):
"""Broadcast a list of float values."""
return broadcast_list(size, torch.float32, list_values=float_list,
rank=rank)
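# Illustrative sketch (hypothetical helper): how callers such as the generation
# API pack mixed scalar parameters into one float tensor on rank-0 and unpack
# them on every rank. Requires an initialized torch.distributed process group
# with a CUDA device, as with the functions above.
def _example_broadcast_params(tokens_to_generate=0, temperature=1.0, add_BOS=False):
    values = [tokens_to_generate, temperature, add_BOS]
    values_float_tensor = broadcast_float_list(len(values), float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())   # back to int
    temperature = values_float_tensor[1].item()               # stays float
    add_BOS = bool(values_float_tensor[2].item())             # back to bool
    return tokens_to_generate, temperature, add_BOS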
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Forward step utilities."""
from collections.abc import Iterable
import torch
from megatron.training import get_args
from megatron.core import mpu, InferenceParams
from .communication import (
send_to_next_pipeline_rank,
recv_from_prev_pipeline_rank_)
class ForwardStep:
"""Forward step function with all the communications.
We use a class here to hide the inference parameters
from the outside caller."""
def __init__(self, model, max_batch_size, max_sequence_length):
"""Set values so we don't need to do it multiple times."""
# Make sure model is in eval mode.
assert not isinstance(model, Iterable), \
'interleaving schedule is not supported for inference'
model.eval()
self.model = model
# Initialize inference parameters.
self.inference_params = InferenceParams(max_batch_size,
max_sequence_length)
# Pipelining arguments.
args = get_args()
self.pipeline_size_larger_than_one = (
args.pipeline_model_parallel_size > 1)
# Threshold of pipelining.
self.pipelining_batch_x_seqlen = \
args.inference_batch_times_seqlen_threshold
def _forward(self, tokens, position_ids, attention_mask):
return self.model(tokens, position_ids, attention_mask, inference_params=self.inference_params)
def __call__(self, tokens, position_ids, attention_mask):
"""Invocation of the forward methods. Note that self.inference_params
is being modified by the forward step."""
# Pipelining case.
if self.pipeline_size_larger_than_one:
current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = \
max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
return self._with_pipelining_forward_step(tokens,
position_ids,
attention_mask,
micro_batch_size)
return self._no_pipelining_forward_step(tokens,
position_ids,
attention_mask)
def _forward_step_helper(self, tokens, position_ids, attention_mask, recv_buffer=None):
"""Single forward step. Update the allocate memory flag so
only the first time the memory is allocated."""
batch_size = tokens.size(0)
sequence_length = tokens.size(1)
if recv_buffer is None:
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
# Receive from previous stage.
recv_from_prev_pipeline_rank_(recv_buffer)
# Forward pass through the model.
self.model.set_input_tensor(recv_buffer)
output_tensor = self._forward(tokens, position_ids, attention_mask)
# Send output to the next stage.
send_to_next_pipeline_rank(output_tensor)
return output_tensor
def _no_pipelining_forward_step(self, tokens, position_ids, attention_mask,
recv_buffer=None):
"""If recv_buffer is none, we will allocate one on the fly."""
# Run a simple forward pass.
output_tensor = self._forward_step_helper(tokens, position_ids,
attention_mask, recv_buffer=recv_buffer)
# Update the sequence length offset.
self.inference_params.sequence_len_offset += tokens.size(1)
logits = None
if mpu.is_pipeline_last_stage():
logits = output_tensor
return logits
def _with_pipelining_forward_step(self, tokens, position_ids, attention_mask, micro_batch_size):
"""No interleaving is supported."""
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
# Divide the batch dimension into micro batches.
num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
if last_chunk > 0:
num_micro_batches += 1
# Preallocate memory for output logits.
logits = None
if mpu.is_pipeline_last_stage():
args = get_args()
logits = torch.empty(
(batch_size, sequence_length, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
# Preallocate recv buffer.
recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
for micro_batch_index in range(num_micro_batches):
            # Slice among the batch dimension.
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
# Run a simple forward pass.
if this_micro_batch_size != micro_batch_size:
recv_buffer = None
output = self._forward_step_helper(tokens2use, position_ids2use, attention_mask, recv_buffer=recv_buffer)
# Adjust the batch size offset to account for the micro-batch.
self.inference_params.batch_size_offset += this_micro_batch_size
# Copy logits.
if mpu.is_pipeline_last_stage():
logits[start:end, ...] = output
# Once we are done with all the micro-batches, we can
# adjust the sequence length offset.
self.inference_params.sequence_len_offset += sequence_length
# and reset the batch size offset
self.inference_params.batch_size_offset = 0
return logits
def _get_recv_buffer_dtype(args):
"""Receive happens between the layers."""
if args.fp32_residual_connection:
return torch.float
return args.params_dtype
def _allocate_recv_buffer(batch_size, sequence_length):
"""Receive happens between the layers with size [s, b, h]."""
if mpu.is_pipeline_first_stage():
return None
args = get_args()
recv_size = (sequence_length, batch_size, args.hidden_size)
return torch.empty(recv_size,
dtype=_get_recv_buffer_dtype(args),
device=torch.cuda.current_device())
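# Illustrative helper (not part of the class): the micro-batch split used by
# _with_pipelining_forward_step, shown in isolation. divmod gives the number of
# full micro-batches plus one extra chunk for any remainder.
def _example_micro_batch_ranges(batch_size, micro_batch_size):
    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
    ranges = []
    for micro_batch_index in range(num_micro_batches):
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        ranges.append((start, end))
    return ranges  # e.g. batch_size=10, micro_batch_size=4 -> [(0, 4), (4, 8), (8, 10)]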
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Generation utilities."""
import torch
import torch.nn.functional as F
from megatron.training import get_args, get_tokenizer
from megatron.core import mpu
from megatron.training.utils import get_ltor_masks_and_position_ids
from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample
from .beam_utils import BeamHypotheses
def score_and_return_on_first_stage(model, tokens, lengths):
"""Function for just scoring.
Args:
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max_prompt_length]
lengths: original prompt length, size: [b]
Note: Outside of model, other parameters only need to be available on
rank 0.
Returns:
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
batch_size = tokens.size(0)
max_prompt_length = lengths.max().item()
assert max_prompt_length == tokens.size(1)
if max_prompt_length > args.max_position_embeddings:
raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
if max_prompt_length * batch_size > args.max_tokens_to_oom:
raise ValueError("Too many tokens. " + str(max_prompt_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom))
# forward step.
forward_step = ForwardStep(model, batch_size, max_prompt_length)
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_prompt_length - 1)
if mpu.is_pipeline_last_stage():
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
    # =============
    # Run inference
    # =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
        # logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens, position_ids, attention_mask)
if mpu.is_pipeline_last_stage():
            # The last stage should always have an output.
assert logits is not None
log_probs = F.log_softmax(logits, dim=2)
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(tokens[:, 1:], 2)
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, lengths, output_log_probs, logits
def generate_tokens_probs_and_return_on_first_stage(
model, forward_step, tokens, lengths,
return_output_log_probs=False,
top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
prevent_newline_after_colon=True
):
"""Main token generation function.
Args:
model: no interleaving is supported.
forward_step (ForwardStep): Class for running the model forward step.
tokens: prompt tokens extended to be of size [b, max-sequence-length]
lengths: original prompt length, size: [b]
return_output_log_probs: flag to calculate the log probability of
the generated tokens. Note that the log probability is the one
from the original logit.
        top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is greedy. Also, these parameters are
            exclusive, meaning that:
                if top-k > 0 then we expect top-p=0.
                if top-p > 0 then we check for top-k=0.
temperature: sampling temperature.
use_eod_token_for_early_termination: if True, do early termination if
all the sequences have reached this token.
prevent_newline_after_colon: if True, it will disable generating new line \n after :
Note: Outside of model, other parameters only need to be available on
rank 0.
    Returns: Note that the size is adjusted to a lower value than
        max-sequence-length if generation is terminated early.
tokens: prompt and generated tokens. size: [b, :]
generated_sequence_lengths: total length (including prompt) of
the generated sequence. size: [b]
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
min_prompt_length = lengths.min().item()
max_sequence_length = tokens.size(1)
if max_sequence_length > args.max_position_embeddings:
raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
if max_sequence_length * batch_size > args.max_tokens_to_oom:
raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom))
# forward step.
forward_step = forward_step(model, batch_size, max_sequence_length)
# Added termination_id to support the case that we want to terminate the
# generation once that id is generated.
if hasattr(args, 'eos_id'):
termination_id = args.eos_id
else:
termination_id = tokenizer.eod
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Lengths of generated sequences including prompts.
generated_sequence_lengths = None
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
generated_sequence_lengths = torch.ones(
batch_size, dtype=torch.int64,
device=torch.cuda.current_device()) * max_sequence_length
# Whether we have reached a termination id.
is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
device=torch.cuda.current_device())
    # =============
    # Run inference
    # =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
            # logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
if prevent_newline_after_colon:
logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":"
                # The last stage should always have an output.
assert logits is not None
# Sample.
last_token_logits = logits[:, -1, :]
new_sample = sample(last_token_logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
if top_p > 0.0 and top_p_decay > 0.0:
top_p = top_p * top_p_decay
if top_p_bound > 0.0:
top_p = max(top_p, top_p_bound)
                # If a prompt length is smaller than or equal to the current context
                # length, it means we have started generating tokens
started = lengths <= context_length
# Update the tokens.
tokens[started, context_length] = new_sample[started]
# Calculate the log probabilities.
if return_output_log_probs:
log_probs = F.log_softmax(logits, dim=2)
if return_output_log_probs:
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(
tokens[
:,
(prev_context_length + 1):(context_length + 1)],
2)
output_log_probs[:,
prev_context_length:context_length] = \
torch.gather(log_probs, 2, indices).squeeze(2)
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
tokens[:, context_length])
# Update the context length for the next token generation.
prev_context_length = context_length
# Check if all the sequences have hit the termination_id.
done = None
if mpu.is_pipeline_last_stage():
# TODO(rprenger) These stopping methods are tokenizer dependent
# instead tokenization should be in the inference loop so stop sequences can be used
if stop_on_double_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
done_token = hit_double_eol | hit_two_eols
elif stop_on_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_eol = (new_sample == 198).byte() & started.byte()
done_token = hit_double_eol | hit_eol
else:
done_token = (new_sample == termination_id).byte() & \
started.byte()
just_finished = (done_token & ~is_generation_done).bool()
generated_sequence_lengths[just_finished.view(-1)] = \
context_length + 1
is_generation_done = is_generation_done | done_token
done = torch.all(is_generation_done)
done = broadcast_from_last_pipeline_stage(1, torch.uint8,
tensor=done)
if use_eod_token_for_early_termination and done:
break
    # ===================================================
    # Update the length based on the max generated length.
    # ===================================================
tokens = tokens[:, :(context_length + 1)]
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = output_log_probs[:, :context_length]
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
batch_size, torch.int64, generated_sequence_lengths)
if return_output_log_probs:
output_log_probs_size = (batch_size, context_length)
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, generated_sequence_lengths, output_log_probs, None
def beam_search_and_return_on_first_stage(model, forward_step, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty, prevent_newline_after_colon=True):
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
assert(batch_size == 1)
prompt_length = lengths.item()
final_sequence_length = tokens.size(1)
final_sequence_length = min(final_sequence_length, args.max_position_embeddings)
    # If the prompt already fills the final sequence length, generation cannot proceed.
if prompt_length >= final_sequence_length:
raise ValueError("context length + tokens_to_generate too large")
# forward step.
forward_step = forward_step(model, beam_size, final_sequence_length)
beam_hyp = BeamHypotheses(beam_size, length_penalty)
best_batches = None
done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
scores = torch.zeros(beam_size,
dtype=torch.float32,
device=torch.cuda.current_device()).unsqueeze(1)
scores_size_tensor, tokens_size_tensor = None, None
    # =============
    # Run inference
    # =============
with torch.no_grad():
tokens = tokens.repeat(beam_size, 1)
attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
prev_context_length = 0
for context_length in range(prompt_length, final_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
            # logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
if prevent_newline_after_colon:
logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":"
vocab_size = logits.size(2)
log_probs = F.log_softmax(logits, dim=2)
new_scores = log_probs[:, -1, :] + scores
if context_length == prompt_length: # if this is the first one
sorted_scores, indices = torch.sort(new_scores[0,:], descending=True)
else:
sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
best_words = indices[:2 * beam_size] % vocab_size
best_scores = sorted_scores[: 2 * beam_size]
next_beams = []
for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
zip(best_words, best_scores, best_beam_ids)
):
if token_id.item() == stop_token:
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
if is_beam_token_worse_than_top_num_beams:
continue
beam_hyp.add(
tokens[beam_id].clone(),
beam_score,
context_length + 1 - prompt_length
)
else:
# add next predicted token since it is not eos_token
next_beams.append((token_id, beam_score, beam_id))
if len(next_beams) == beam_size:
break
if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())
best_batches = tokens.new([item[2] for item in next_beams])
tokens = tokens[best_batches,:]
tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
# torch.distributed.barrier()
done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
if done:
break
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64,
tokens)
# set inference key values to make it consistent with best beam index
best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches)
forward_step.inference_params.swap_key_value_dict(best_batches)
# Update the context length for the next token generation.
prev_context_length = context_length
if mpu.is_pipeline_last_stage():
# if the stop token was never generated, add the open beams to the hypotheses
if not done:
for beam_id in range(beam_size):
beam_hyp.add(tokens[beam_id].clone(), scores[beam_id].squeeze(), context_length + 1 - prompt_length)
# rank based on scores
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
num_return_gen = min(num_return_gen, len(sorted_hyps))
scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
scores = torch.stack(scores, dim=0)
tokens = torch.stack(tokens, dim=0)
scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device())
tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device())
scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)
scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores)
tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens)
return tokens, scores
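# Note: the beam scores accumulated above are running sums of log-probabilities,
# so longer hypotheses collect more negative terms. The length_penalty passed to
# BeamHypotheses is the standard way to trade this off when ranking finished
# beams. The helper below is only an illustrative sketch of that normalization
# (a hypothetical function, not the BeamHypotheses implementation, which is
# defined elsewhere in this package):
def _illustrative_length_normalized_score(sum_log_probs, length, length_penalty=1.0):
    """Sketch: length_penalty > 1.0 favors longer generations, < 1.0 shorter ones."""
    return sum_log_probs / (length ** length_penalty)
# Two beams with the same average per-token log-probability then rank equally,
# e.g. _illustrative_length_normalized_score(-4.0, 4) == _illustrative_length_normalized_score(-8.0, 8) == -1.0.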
def _build_attention_mask_and_position_ids(tokens):
"""Build the attention mask and postition ids for the input tokens."""
# Since we are not interested in loss-mask and reset attention/position
# is also False, eod_token is not used so it is safe to set it to None.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
eod_token=None,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False)
return attention_mask, position_ids
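# For reference, the mask built above is a standard causal (left-to-right) mask.
# The self-contained sketch below reproduces the same shapes with plain torch
# ops; it is an illustration only, since get_ltor_masks_and_position_ids also
# supports attention/position resets and loss masking (all disabled here).
def _illustrative_causal_mask_and_positions(batch_size=2, seq_len=5):
    """Sketch: boolean causal mask [1, 1, s, s] (True = masked) and position ids [b, s]."""
    import torch
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len)
    causal = torch.tril(torch.ones(1, 1, seq_len, seq_len))
    attention_mask = causal < 0.5  # True marks future positions to mask out
    return attention_mask, position_ids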
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Sampling utilities.
Part of this code is inspired by:
- https://github.com/ari-holtzman/degen/blob/master/gen.py
- https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
"""
import torch
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Filter based on the cumulative sum.
filter_ = cumulative_probs > top_p
# Shift the mask right by one so the first token that pushes the cumulative
# probability above top-p is still kept. This matches the original
# implementation:
# https://github.com/ari-holtzman/degen/blob/master/gen.py
filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
# Fill in the filtered part
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
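# Worked example (illustration only): with per-token probabilities
# [0.5, 0.25, 0.15, 0.10] and top_p = 0.8, the cumulative sums over the sorted
# tokens are [0.5, 0.75, 0.90, 1.0]. Because of the shift above, the first token
# that pushes the cumulative probability past top_p is still kept, so the three
# most likely tokens survive and only the 0.10 token is masked to -inf.
def _example_top_p_filtering():
    """Sketch demonstrating modify_logits_for_top_p_filtering on toy logits."""
    logits = torch.log(torch.tensor([[0.5, 0.25, 0.15, 0.10]]))
    modify_logits_for_top_p_filtering(logits, top_p=0.8)
    return logits  # [[log 0.5, log 0.25, log 0.15, -inf]]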
def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
""" Sample and generate a token.
Note: logits has the dimension [b, v] where b is the batch size
and v is the vocabulary size.
If vocab_size is provided, we will make sure the sample that is
generated is in [0, vocab_size). This avoids out-of-vocabulary
generations due to padding.
"""
# Check logits for consistency.
assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
assert logits.type() == 'torch.cuda.FloatTensor', \
'input logits should be floats.'
# Greedy is just simple argmax.
if top_k == 1:
assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
samples = torch.argmax(logits, dim=-1)
# Top-k or top-p sampling.
else:
# Clone so we do not modify the inputs.
logits = logits.clone()
# Apply temperature in place.
if temperature != 1.0:
logits.div_(temperature)
if top_k > 1:
assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
assert top_k <= logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(logits, top_k)
elif top_p > 0.0:
assert top_p <= 1.0, 'top-p should be in (0, 1].'
modify_logits_for_top_p_filtering(logits, top_p)
# After filtering, we need to recalculate the distribution.
probs = logits.softmax(dim=-1)
samples = torch.multinomial(probs, num_samples=1).view(-1)
# If vocab size is provided, make sure the samples are
# in the range [0, vocab_size).
if vocab_size:
samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
return samples
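# Usage sketch (illustration only; the assertions above require CUDA float32
# logits, so this example is guarded on GPU availability and the shapes are
# made up):
def _example_sampling():
    """Sketch: greedy, top-k, and top-p sampling on random logits."""
    if not torch.cuda.is_available():
        return None
    logits = torch.randn(4, 32000, device='cuda')      # [batch, vocab]
    greedy = sample(logits, top_k=1)                    # argmax per row
    top_k = sample(logits, top_k=50, temperature=0.8)   # sample from the top 50 tokens
    top_p = sample(logits, top_p=0.9)                   # nucleus sampling
    return greedy, top_k, top_p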
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tokenization utilities."""
import torch
from megatron.training import get_tokenizer, get_args
from .communication import broadcast_int_list, broadcast_tensor
def detokenize_generations(tokens_gpu_tensor,
lengths_gpu_tensor,
return_segments):
"""Detokenize the generated tokens."""
tokenizer = get_tokenizer()
args = get_args()
prompts_plus_generations = []
if return_segments:
prompts_plus_generations_segments = []
tokens = tokens_gpu_tensor.cpu().numpy().tolist()
lengths = lengths_gpu_tensor.cpu().numpy().tolist()
for sequence_tokens, length in zip(tokens, lengths):
sequence_tokens = sequence_tokens[:length]
prompts_plus_generations.append(
tokenizer.detokenize(sequence_tokens))
if return_segments:
words = []
for token in sequence_tokens:
if args.tokenizer_type in ['SentencePieceTokenizer',
'GPTSentencePieceTokenizer',
'HuggingFaceTokenizer',
'Llama2Tokenizer',
'MistralTokenizer']:
word = tokenizer.decoder[token]
elif args.tokenizer_type == 'Llama3Tokenizer':
word = tokenizer.decode([token])
elif args.tokenizer_type == 'NullTokenizer':
word = str(token)
else:
word = tokenizer.tokenizer.decoder[token]
word = bytearray(
[tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
'utf-8', errors='replace')
words.append(word)
prompts_plus_generations_segments.append(words)
if return_segments:
return tokens, prompts_plus_generations, \
prompts_plus_generations_segments
return tokens, prompts_plus_generations
def tokenize_prompts(prompts=None, tokens_to_generate=None,
add_BOS=None, rank=0):
"""Tokenize prompts and make them avaiable on all ranks."""
# On all ranks set to None so we can pass them to functions
sizes_list = None
prompts_tokens_cuda_long_tensor = None
prompts_length_cuda_long_tensor = None
# On the specified rank, build the above.
if torch.distributed.get_rank() == rank:
assert prompts is not None
assert tokens_to_generate is not None
# Tensor of tokens padded and their unpadded length.
prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor = \
_tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS)
# We need the sizes of these tensors for the broadcast
sizes_list = [prompts_tokens_cuda_long_tensor.size(0), # Batch size
              prompts_tokens_cuda_long_tensor.size(1)] # Sequence length
# First, broadcast the sizes.
sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=rank)
# Now that we have the sizes, we can broadcast the tokens
# and length tensors.
sizes = sizes_tensor.tolist()
prompts_tokens_cuda_long_tensor = broadcast_tensor(
sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank)
prompts_length_cuda_long_tensor = broadcast_tensor(
sizes[0], torch.int64, tensor=prompts_length_cuda_long_tensor,
rank=rank)
return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor
def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS):
"""Given a set of prompts and number of tokens to generate:
- tokenize prompts
- set the sequence length to be the max of length of prompts
plus the number of tokens we would like to generate
- pad all the sequences to this length so we can convert them
into a 2D tensor.
"""
# Tokenize all the prompts.
tokenizer = get_tokenizer()
if add_BOS:
prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt)
for prompt in prompts]
else:
prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts]
# Now we have a list of lists of tokens, where each list has a different
# size. We want to extend each list to:
# - incorporate the tokens that need to be generated
# - make all the sequences equal length.
# Get the prompts length.
prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens]
# Get the max prompts length.
max_prompt_len = max(prompts_length)
# Number of tokens in each sample of the batch.
samples_length = max_prompt_len + tokens_to_generate
# Now update the list of list to be of the same size: samples_length.
for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
padding_size = samples_length - prompt_length
prompt_tokens.extend([tokenizer.eod] * padding_size)
# Now that we are in a structured format, we can convert to tensors.
prompts_tokens_tensor = torch.tensor(prompts_tokens, dtype=torch.long, device='cuda')
prompts_length_tensor = torch.tensor(prompts_length, dtype=torch.long, device='cuda')
return prompts_tokens_tensor, prompts_length_tensor
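# Shape sketch (illustration only, with made-up token ids rather than a real
# Megatron tokenizer): two prompts of lengths 3 and 1 with tokens_to_generate=2
# are padded with the EOD token up to max_prompt_len + tokens_to_generate = 5,
# while the returned lengths keep the unpadded sizes.
def _example_prompt_padding(eod=0, tokens_to_generate=2):
    """Sketch of the padding performed above, without a real tokenizer."""
    prompts_tokens = [[11, 12, 13], [21]]
    prompts_length = [len(prompt) for prompt in prompts_tokens]
    samples_length = max(prompts_length) + tokens_to_generate
    for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
        prompt_tokens.extend([eod] * (samples_length - prompt_length))
    # prompts_tokens is now [[11, 12, 13, 0, 0], [21, 0, 0, 0, 0]]
    # prompts_length is still [3, 1]
    return torch.tensor(prompts_tokens), torch.tensor(prompts_length)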
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import datetime
import torch
import json
import threading
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron.training import get_args
from megatron.inference.text_generation import generate_and_post_process
from megatron.inference.text_generation import beam_search_and_post_process
GENERATE_NUM = 0
BEAM_NUM = 1
lock = threading.Lock()
class MegatronGenerate(Resource):
def __init__(self, model):
self.model = model
@staticmethod
def send_do_generate():
choice = torch.tensor([GENERATE_NUM], dtype=torch.long, device='cuda')
torch.distributed.broadcast(choice, 0)
@staticmethod
def send_do_beam_search():
choice = torch.tensor([BEAM_NUM], dtype=torch.long, device='cuda')
torch.distributed.broadcast(choice, 0)
def put(self):
args = get_args()
if not "prompts" in request.get_json():
return "prompts argument required", 400
if "max_len" in request.get_json():
return "max_len is no longer used. Replace with tokens_to_generate", 400
if "sentences" in request.get_json():
return "sentences is no longer used. Replace with prompts", 400
prompts = request.get_json()["prompts"]
if not isinstance(prompts, list):
return "prompts is not a list of strings", 400
if len(prompts) == 0:
return "prompts is empty", 400
if len(prompts) > 128:
return "Maximum number of prompts is 128", 400
tokens_to_generate = 64  # Hopefully a sane default; generating the full sequence is slow
if "tokens_to_generate" in request.get_json():
tokens_to_generate = request.get_json()["tokens_to_generate"]
if not isinstance(tokens_to_generate, int):
return "tokens_to_generate must be an integer greater than 0"
if tokens_to_generate < 0:
return "tokens_to_generate must be an integer greater than or equal to 0"
logprobs = False
if "logprobs" in request.get_json():
logprobs = request.get_json()["logprobs"]
if not isinstance(logprobs, bool):
return "logprobs must be a boolean value"
if tokens_to_generate == 0 and not logprobs:
return "tokens_to_generate=0 implies logprobs should be True"
temperature = 1.0
if "temperature" in request.get_json():
temperature = request.get_json()["temperature"]
if not (type(temperature) == int or type(temperature) == float):
return "temperature must be a positive number less than or equal to 100.0"
if not (0.0 < temperature <= 100.0):
return "temperature must be a positive number less than or equal to 100.0"
top_k = 0
if "top_k" in request.get_json():
top_k = request.get_json()["top_k"]
if not (type(top_k) == int):
return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
if not (0 <= top_k <= 1000):
return "top_k must be equal to or greater than 0 and less than or equal to 1000"
top_p = 0.0
if "top_p" in request.get_json():
top_p = request.get_json()["top_p"]
if not (type(top_p) == float):
return "top_p must be a positive float less than or equal to 1.0"
if top_p > 0.0 and top_k > 0.0:
return "cannot set both top-k and top-p samplings."
if not (0 <= top_p <= 1.0):
return "top_p must be less than or equal to 1.0"
top_p_decay = 0.0
if "top_p_decay" in request.get_json():
top_p_decay = request.get_json()["top_p_decay"]
if not (type(top_p_decay) == float):
return "top_p_decay must be a positive float less than or equal to 1.0"
if top_p == 0.0:
return "top_p_decay cannot be set without top_p"
if not (0 <= top_p_decay <= 1.0):
return "top_p_decay must be less than or equal to 1.0"
top_p_bound = 0.0
if "top_p_bound" in request.get_json():
top_p_bound = request.get_json()["top_p_bound"]
if not (type(top_p_bound) == float):
return "top_p_bound must be a positive float less than or equal to top_p"
if top_p == 0.0:
return "top_p_bound cannot be set without top_p"
if not (0.0 < top_p_bound <= top_p):
return "top_p_bound must be greater than 0 and less than top_p"
add_BOS = False
if "add_BOS" in request.get_json():
add_BOS = request.get_json()["add_BOS"]
if not isinstance(add_BOS, bool):
return "add_BOS must be a boolean value"
if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS:
return "Empty prompts require add_BOS=true"
stop_on_double_eol = False
if "stop_on_double_eol" in request.get_json():
stop_on_double_eol = request.get_json()["stop_on_double_eol"]
if not isinstance(stop_on_double_eol, bool):
return "stop_on_double_eol must be a boolean value"
stop_on_eol = False
if "stop_on_eol" in request.get_json():
stop_on_eol = request.get_json()["stop_on_eol"]
if not isinstance(stop_on_eol, bool):
return "stop_on_eol must be a boolean value"
prevent_newline_after_colon = False
if "prevent_newline_after_colon" in request.get_json():
prevent_newline_after_colon = request.get_json()["prevent_newline_after_colon"]
if not isinstance(prevent_newline_after_colon, bool):
return "prevent_newline_after_colon must be a boolean value"
random_seed = -1
if "random_seed" in request.get_json():
random_seed = request.get_json()["random_seed"]
if not isinstance(random_seed, int):
return "random_seed must be integer"
if random_seed < 0:
return "random_seed must be a positive integer"
no_log = False
if "no_log" in request.get_json():
no_log = request.get_json()["no_log"]
if not isinstance(no_log, bool):
return "no_log must be a boolean value"
beam_width = None
if "beam_width" in request.get_json():
beam_width = request.get_json()["beam_width"]
if not isinstance(beam_width, int):
return "beam_width must be integer"
if beam_width < 1:
return "beam_width must be an integer > 1"
if len(prompts) > 1:
return "When doing beam_search, batch size must be 1"
stop_token = 50256
if "stop_token" in request.get_json():
stop_token = request.get_json()["stop_token"]
if not isinstance(stop_token, int):
return "stop_token must be an integer"
length_penalty = 1
if "length_penalty" in request.get_json():
length_penalty = request.get_json()["length_penalty"]
if not isinstance(length_penalty, float):
return "length_penalty must be a float"
with lock:  # Acquire the lock so concurrent requests do not run generation at the same time
if not no_log:
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("start time: ", datetime.datetime.now())
try:
if beam_width is not None:
MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search
response, response_seg, response_scores = \
beam_search_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
beam_size = beam_width,
add_BOS=add_BOS,
stop_token=stop_token,
num_return_gen=beam_width, # Returning whole beam
length_penalty=length_penalty,
prevent_newline_after_colon=prevent_newline_after_colon
)
return jsonify({"text": response,
"segments": response_seg,
"scores": response_scores})
else:
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=logprobs,
top_k_sampling=top_k,
top_p_sampling=top_p,
top_p_decay=top_p_decay,
top_p_bound=top_p_bound,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=True,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
prevent_newline_after_colon=prevent_newline_after_colon,
random_seed=random_seed)
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
except ValueError as ve:
return ve.args[0]
print("end time: ", datetime.datetime.now())
class MegatronServer(object):
def __init__(self, model):
self.app = Flask(__name__, static_url_path='')
api = Api(self.app)
api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
def run(self, url, port):
self.app.run(url, threaded=True, debug=False, port=port)
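# Usage sketch (illustration only; the host, port, and payload values below are
# assumptions). Rank 0 serves HTTP while the other ranks wait for the choice
# broadcast by send_do_generate / send_do_beam_search:
#
#   server = MegatronServer(model)
#   server.run("0.0.0.0", port=5000)
#
# A client then issues a PUT request against the /api resource, e.g. with the
# requests package:
#
#   import requests
#   resp = requests.put("http://localhost:5000/api",
#                       json={"prompts": ["Hello, my name is"],
#                             "tokens_to_generate": 32,
#                             "top_k": 50,
#                             "temperature": 0.9})
#   print(resp.json()["text"])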