Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed with stage
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -53,35 +53,63 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
     Returns:
         int: The number of layers to be built for the current pipeline stage.
     """
-    if config.first_pipeline_num_layers is not None or config.last_pipeline_num_layers is not None:
-        assert (
-            parallel_state.get_virtual_pipeline_model_parallel_world_size() is None
-        ), "Uneven number of layer not compatible with interleaved pipeline schedule"
+    if (
+        config.num_layers_in_first_pipeline_stage is not None
+        or config.num_layers_in_last_pipeline_stage is not None
+    ):
+        assert not (
+            config.account_for_embedding_in_pipeline_split
+            or config.account_for_loss_in_pipeline_split
+        ), " \
+            Does not support standalone embedding stage and standalone loss stage with uneven pp"
         # Number of layers to distribute over rest of pipeline stages
         layers_to_distribute = config.num_layers

         # Number of pipeline stages left for distributing transformer layers
         pipeline_stages_left = parallel_state.get_pipeline_model_parallel_world_size()

-        if config.first_pipeline_num_layers is not None:
-            layers_to_distribute -= config.first_pipeline_num_layers
+        # If the uneven first (last) pipeline stage is enabled, remove the specified number
+        # of layers to calculate the number of layers on each middle pipeline stage.
+        if config.num_layers_in_first_pipeline_stage is not None:
+            layers_to_distribute -= config.num_layers_in_first_pipeline_stage
             pipeline_stages_left -= 1
-            if parallel_state.is_pipeline_first_stage():
-                return config.first_pipeline_num_layers

-        if config.last_pipeline_num_layers is not None:
-            layers_to_distribute -= config.last_pipeline_num_layers
+        if config.num_layers_in_last_pipeline_stage is not None:
+            layers_to_distribute -= config.num_layers_in_last_pipeline_stage
             pipeline_stages_left -= 1
-            if parallel_state.is_pipeline_last_stage():
-                return config.last_pipeline_num_layers

         assert (
             layers_to_distribute % pipeline_stages_left == 0
         ), "With uneven pipelineing the left over layers must be divisible by left over stages"
         num_layers_per_pipeline_rank = layers_to_distribute // pipeline_stages_left
+
+        # If the uneven first (last) pipeline stage is enabled, return the specified number
+        # of layers for all virtual pipeline parallel stages within the first (last) pipeline
+        # parallel stage.
+        if (
+            parallel_state.is_pipeline_first_stage(ignore_virtual=True)
+            and config.num_layers_in_first_pipeline_stage is not None
+        ):
+            num_layers_per_pipeline_rank = config.num_layers_in_first_pipeline_stage
+
+        if (
+            parallel_state.is_pipeline_last_stage(ignore_virtual=True)
+            and config.num_layers_in_last_pipeline_stage is not None
+        ):
+            num_layers_per_pipeline_rank = config.num_layers_in_last_pipeline_stage
     else:
-        pipeline_ranks = config.pipeline_model_parallel_size
-        num_layers_per_pipeline_rank = config.num_layers // pipeline_ranks
+        # Include the embedding layer and loss layer into pipeline parallelism partition
+        num_layers = config.num_layers
+        if config.account_for_embedding_in_pipeline_split:
+            num_layers += 1
+
+        if config.account_for_loss_in_pipeline_split:
+            num_layers += 1
+
+        assert (
+            num_layers % config.pipeline_model_parallel_size == 0
+        ), "num_layers should be divisible by pipeline_model_parallel_size"
+        num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size

     if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
         # Interleaved pipeline parallelism:
@@ -95,9 +123,11 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
         # layers to stages like (each list is a model chunk):
         # Stage 0: [0, 1] [4, 5]
         # Stage 1: [2, 3] [6, 7]
         vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+        assert (
+            num_layers_per_pipeline_rank % vp_size == 0
+        ), "num_layers_per_pipeline_rank should be divisible by vp_size"
         num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size

         num_layers_to_build = num_layers_per_virtual_rank
@@ -105,9 +135,19 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
     else:
         # Non-interleaved pipeline parallelism:
         # Each stage gets a contiguous set of layers.
         num_layers_to_build = num_layers_per_pipeline_rank

+    # The embedding (or loss) layer cannot function as a standalone transformer layer
+    # Reduce the number of layers to construct by 1 on the first (or last) stage if the
+    # embedding (or loss) layer is included in the pipeline parallelism partition and placement.
+    if parallel_state.is_pipeline_first_stage() and config.account_for_embedding_in_pipeline_split:
+        num_layers_to_build -= 1
+        assert num_layers_to_build >= 0, "Not enough layers in the first virtual pipeline stage"
+
+    if parallel_state.is_pipeline_last_stage() and config.account_for_loss_in_pipeline_split:
+        num_layers_to_build -= 1
+        assert num_layers_to_build >= 0, "Not enough layers in the last virtual pipeline stage"
+
     return num_layers_to_build
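The hunks above redistribute transformer layers when the first or last pipeline stage is given an explicit size. A minimal standalone sketch of that arithmetic, with made-up numbers; `layers_per_stage` is a hypothetical helper, not a Megatron API, and it only mirrors the divisibility rules enforced by the asserts above:

def layers_per_stage(num_layers, pp_size, first=None, last=None):
    """Return the number of transformer layers built on each pipeline stage."""
    stages_left = pp_size
    layers_left = num_layers
    if first is not None:
        layers_left -= first
        stages_left -= 1
    if last is not None:
        layers_left -= last
        stages_left -= 1
    assert layers_left % stages_left == 0, "middle layers must divide evenly"
    middle = layers_left // stages_left
    return (
        ([first] if first is not None else [])
        + [middle] * stages_left
        + ([last] if last is not None else [])
    )

print(layers_per_stage(30, 4, first=6, last=6))  # [6, 9, 9, 6]
print(layers_per_stage(32, 4))                   # [8, 8, 8, 8]  (even split)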
@@ -242,7 +282,7 @@ class TransformerBlock(MegatronModule):
            ]
        )

-        # @TODO: add back standalone_embedding_stage (see issue #293)
+        # @TODO: add back account_for_embedding_in_pipeline_split (see issue #293)
        # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline
        # self.post_process and self.post_layer_norm guide this behavior
        if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
@@ -404,6 +444,7 @@ class TransformerBlock(MegatronModule):
        attention_bias: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
+        sequence_len_offset: Tensor = None,
    ):
        """
        Perform the forward pass through the transformer block.
@@ -436,6 +477,10 @@ class TransformerBlock(MegatronModule):
            # See set_input_tensor()
            hidden_states = self.input_tensor

+        # Update the inference parameters with the current batch size in case it is variable
+        if inference_params and not self.training:
+            inference_params.current_batch_size = hidden_states.size(1)
+
        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
@@ -512,6 +557,7 @@ class TransformerBlock(MegatronModule):
                            attention_bias=attention_bias,
                            inference_params=inference_params,
                            packed_seq_params=packed_seq_params,
+                            sequence_len_offset=sequence_len_offset,
                        )
                    else:
                        # CUDA graph replay for layer `l_no` and microbatch
@@ -576,7 +622,10 @@ class TransformerBlock(MegatronModule):
        non_homogeneous_layers = metadata is not None and metadata.get(
            'non_homogeneous_layers', False
        )
-        if self.config.num_moe_experts is not None:
+        if isinstance(self.config.moe_layer_freq, int):
+            if self.config.moe_layer_freq > 1:
+                non_homogeneous_layers = True
+        elif isinstance(self.config.moe_layer_freq, list):
            non_homogeneous_layers = True

        sharded_state_dict = {}
...
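The last hunk above switches the sharded-state-dict homogeneity check from `num_moe_experts` to `moe_layer_freq`: an integer frequency of 1 keeps every layer identical, while a frequency greater than 1 or an explicit pattern list mixes MoE and dense layers. A small sketch of how such a pattern can be expanded; `expand_moe_pattern` is hypothetical and only illustrates the documented 1:N convention, it is not the Megatron implementation:

def expand_moe_pattern(moe_layer_freq, num_layers):
    """Return a per-layer list where 1 marks an MoE layer and 0 a dense layer."""
    if isinstance(moe_layer_freq, int):
        # 1:N ratio -- one MoE layer for every N-1 dense layers.
        return [1 if (i % moe_layer_freq) == 0 else 0 for i in range(num_layers)]
    # An explicit pattern list is used as-is.
    assert len(moe_layer_freq) == num_layers
    return list(moe_layer_freq)

print(expand_moe_pattern(2, 8))                 # [1, 0, 1, 0, 1, 0, 1, 0]
print(expand_moe_pattern([1, 1, 1, 0] * 2, 8))  # [1, 1, 1, 0, 1, 1, 1, 0]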
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+import warnings
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union
@@ -25,14 +26,22 @@ class TransformerConfig(ModelParallelConfig):
    num_layers: int = 0
    """Number of transformer layers in a transformer block."""

-    first_pipeline_num_layers: int = None
+    num_layers_in_first_pipeline_stage: Optional[int] = None
    """Number of transformer layers on first pipeline stage.
    None implies equal layer division across PP ranks."""

-    last_pipeline_num_layers: int = None
+    num_layers_in_last_pipeline_stage: Optional[int] = None
    """Number of transformer layers on last pipeline stage.
    None implies equal layer division across PP ranks."""

+    account_for_embedding_in_pipeline_split: bool = False
+    """If set, the embedding layer will be treated as a standard transformer
+    layer in the context of partition and placement for pipeline parallelism."""
+
+    account_for_loss_in_pipeline_split: bool = False
+    """If set, the loss layer will be treated as a standard transformer
+    layer in the context of partition and placement for pipeline parallelism."""
+
    hidden_size: int = 0
    """Transformer hidden size."""
@@ -45,14 +54,17 @@ class TransformerConfig(ModelParallelConfig):
    If attention backend is local we use the local pytorch implementation in mcore.
    Users can specify exact backend by changing this config. """

-    num_query_groups: int = None
+    softmax_scale: Optional[float] = None
+    """Softmax scale for attention scaling."""
+
+    num_query_groups: Optional[int] = None
    """Number of query groups for group query attention. If None, normal attention is used."""

-    ffn_hidden_size: int = None
+    ffn_hidden_size: Optional[int] = None
    """Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size
    if not provided."""

-    kv_channels: int = None
+    kv_channels: Optional[int] = None
    """Projection weights dimension in multi-head attention. This is set to hidden_size //
    num_attention_heads if not provided."""
@@ -93,7 +105,7 @@ class TransformerConfig(ModelParallelConfig):
    """Store the input of MLP activation function in FP8 for backprop to save memory.
    The stored input is casted back to the original precision before backprop compuatation."""

-    num_moe_experts: int = None
+    num_moe_experts: Optional[int] = None
    """Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None
    for no MoE."""
@@ -105,7 +117,7 @@ class TransformerConfig(ModelParallelConfig):
    """If not None, then will use sliding window attention. The size of the window is specified by
    the numbers inside the tuple; -1 is special value meaning "infinite window size"."""

-    normalization: bool = "LayerNorm"
+    normalization: str = "LayerNorm"
    """Which norm to use for normalization layers, valid options are `LayerNorm` and `RMSNorm`."""

    qk_layernorm: bool = False
@@ -124,13 +136,13 @@ class TransformerConfig(ModelParallelConfig):
    ####################
    # initialization
    ####################
-    init_method: Callable = None
+    init_method: Optional[Callable] = None
    """Method to initialize weights. Note that bias is always set to zero. Should be a function that
    takes a single Tensor and initializes it. If None, will be set to
    megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with
    mean=0.0 and std=init_method_std."""

-    output_layer_init_method: Callable = None
+    output_layer_init_method: Optional[Callable] = None
    """Method to initialize weights of the output layer of both attention and MLP blocks. If None,
    will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn
    init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers)."""
@@ -139,6 +151,12 @@ class TransformerConfig(ModelParallelConfig):
    """Standard deviation of the zero mean normal for the default initialization method, not used if
    init_method and output_layer_init_method are provided."""

+    init_model_with_meta_device: bool = False
+    """
+    If True, initializes the model with the meta device. This is helpful for
+    training of very large models. This feature is only works when custom fsdp is turned on.
+    """
+
    ####################
    # mixed-precision
    ####################
@@ -176,7 +194,7 @@ class TransformerConfig(ModelParallelConfig):
    ####################
    # activation recomputation
    ####################
-    recompute_granularity: str = None
+    recompute_granularity: Optional[str] = None
    """Determines which type of activation recompute to use. Megatron-core supports 'selective'
    activation checkpointing where only the memory intensive part of attention is checkpointed.
    These memory intensive activations are also less compute intensive which makes activation
@@ -186,7 +204,7 @@ class TransformerConfig(ModelParallelConfig):
    If set, must be 'selective' or 'full'. 'selective' always uses all layers.
    """

-    recompute_method: str = None
+    recompute_method: Optional[str] = None
    """Determines which transformer layers will be recomputed. uniform will uniformly divide the
    total number of transformer layers in a transformer block and recompute the input activation of
    each divided chunk at the specified granularity. block will recompute the input activations for
@@ -194,19 +212,19 @@ class TransformerConfig(ModelParallelConfig):
    pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all
    layers will do recomputation. If set, must be 'uniform' or 'block'."""

-    recompute_num_layers: int = None
+    recompute_num_layers: Optional[int] = None
    """When recompute_method is uniform, recompute_num_layers is the number of transformer layers in
    each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is
    the number of transformer layers to recompute within each pipeline stage. Must be None for
    'selective' activation checkpointing."""

-    distribute_saved_activations: bool = None
+    distribute_saved_activations: Optional[bool] = None
    """If True, distribute recomputed activations across the model parallel group."""

    ####################
    # fp8 related
    ####################
-    fp8: str = None
+    fp8: Optional[str] = None
    """If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined
    choices (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8
    activation and weight tensors and e5m2 for all FP8 output activation gradient tensors."""
@@ -245,7 +263,7 @@ class TransformerConfig(ModelParallelConfig):
    ####################
    # MoE related
    ####################
-    moe_shared_expert_intermediate_size: int = None
+    moe_shared_expert_intermediate_size: Optional[int] = None
    """Shared expert total ffn hidden size.
    It should be equal to 'num_shared_experts * ffn_size_of_each_shared_expert' if
    there are multiple shared experts.
@@ -255,56 +273,104 @@ class TransformerConfig(ModelParallelConfig):
    """Enable overlapping between shared expert computations and dispatcher communications.
    Without this, the shared epxerts execute after the routed experts."""

-    moe_layer_freq: int = 1
+    moe_layer_freq: Union[int, List[int]] = 1
    """Frequency between MoE layers and Dense layers. Accepts either:
    - An integer N: Represents a 1:N ratio, meaning one expert layer for every N-1 dense layers.
-    - A string containing a Python list expression that defines a custom pattern, e.g.:
-    "([1]*3+[0]*1)*3" evaluates to [1,1,1,0,1,1,1,0,1,1,1,0]
-    where 1 indicates an expert layer and 0 indicates a dense layer."""
+    - A list that defines a custom pattern, e.g.: [1,1,1,0,1,1,1,0,1,1,1,0]"""

-    moe_ffn_hidden_size: int = None
+    moe_ffn_hidden_size: Optional[int] = None
    """MoE Feed-Forward Network hidden size"""

    moe_router_load_balancing_type: str = "aux_loss"
-    """Determines the load balancing strategy for the router. "aux_loss" corresponds to the load
-    balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing
-    algorithm used in S-BASE, and "none" implies no load balancing."""
+    """The load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss
+    used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the loss used in DeepSeekV2,
+    which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing
+    algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss"."""

    moe_router_topk: int = 2
    """Number of experts to route to for each token."""

+    moe_router_topk_limited_devices: Optional[int] = None
+    """Number of EP ranks to consider for each token in group-limited routing,
+    DEPRECATED and replaced by moe_router_num_groups and moe_router_group_topk.
+    """
+
+    moe_router_num_groups: Optional[int] = None
+    """Number of groups to divide experts into for group-limited routing.
+    When using group-limited routing:
+    1. Experts are divided into 'moe_router_num_groups' equal-sized groups
+    2. For each token, 'moe_router_group_topk' groups are selected based on routing scores
+       (specifically, the sum of top-2 expert scores within each group)
+    3. From these selected groups, 'moe_router_topk' individual experts are chosen
+    Two common use cases:
+    - Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP)
+      to limit each token to experts on a subset of devices
+      (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434)
+    - Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group
+      to limit each token to experts on a subset of nodes
+      (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)
+    """
+
+    moe_router_group_topk: Optional[int] = None
+    """Number of selected groups for group-limited routing."""
+
    moe_router_pre_softmax: bool = False
    """Enable pre-softmax routing for MoE, which means softmax is before the top-k selection.
    By default, softmax is done after top-k."""

+    moe_router_topk_scaling_factor: Optional[float] = None
+    """Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax
+    enabled. Defaults to None, which means no scaling."""
+
+    moe_router_score_function: str = "softmax"
+    """Score function for MoE routing. Can be "softmax" or "sigmoid"."""
+
+    moe_router_enable_expert_bias: bool = False
+    """TopK routing with dynamic per-expert bias in the aux-loss-free load balancing strategy.
+    The routing decision is based on the sum of the routing scores and the expert bias.
+    See https://arxiv.org/abs/2408.15664 for details."""
+
+    moe_router_bias_update_rate: float = 1e-3
+    """The expert bias is updated based on the number of assigned tokens to each expert
+    in a global batch, where the bias is increased for the experts with less assigned tokens
+    and decreased for the experts with more assigned tokens.
+    The default value 1e-3 is same as that used in DeepSeekV3."""
+
    moe_grouped_gemm: bool = False
    """When there are multiple experts per rank, compress multiple local (potentially small) gemms
    in a single kernel launch to improve the utilization and performance by leveraging the Grouped
    GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
    """

+    moe_use_legacy_grouped_gemm: bool = False
+    """Use legacy GroupedMLP rather than TEGroupedMLP.
+    Note: The legacy one will be deprecated soon."""
+
    moe_aux_loss_coeff: float = 0  # 1e-2 would be a good start value for load balance loss.
    """Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended."""

-    moe_z_loss_coeff: float = None  # 1e-3 would be a good start value for z-loss
+    moe_z_loss_coeff: Optional[float] = None  # 1e-3 would be a good start value for z-loss
    """Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended."""

-    moe_input_jitter_eps: float = None
+    moe_input_jitter_eps: Optional[float] = None
    """Add noise to the input tensor by applying jitter with a specified epsilon value."""

-    moe_token_dropping: bool = False  # TODO: Support token dropping.
+    moe_token_dropping: bool = False
    """This feature involves selectively dropping and padding tokens for each expert to achieve a
    specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is
    currently unsupported so should remain False."""

    moe_token_dispatcher_type: str = "allgather"
    """The type of token dispatcher to use. The default is 'allgather'.
-    Options are 'allgather' and 'alltoall'."""
+    Options are 'allgather','alltoall' and 'flex'."""

+    moe_enable_deepep: bool = False
+    """[Experimental] Enable DeepEP for efficient token dispatching and combine in MoE models."""
+
    moe_per_layer_logging: bool = False
    """Enable per-layer logging for MoE, currently supports auxiliary loss and z loss."""

-    moe_expert_capacity_factor: float = None
+    moe_expert_capacity_factor: Optional[float] = None
    """moe_expert_capacity_factor (float): The capacity factor for each expert, None means no token
    will be dropped. The default is None."""
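The group-limited routing fields added above (moe_router_num_groups, moe_router_group_topk) restrict each token to experts inside its highest-scoring groups before the usual top-k pick. A rough PyTorch sketch of that selection, assuming routing scores of shape [num_tokens, num_experts]; `group_limited_topk` is illustrative only and not the Megatron router code:

import torch

def group_limited_topk(scores, num_groups, group_topk, topk):
    """Pick topk experts per token, restricted to the group_topk best groups."""
    num_tokens, num_experts = scores.shape
    # Score each group by the sum of its top-2 expert scores, as described above.
    group_scores = (
        scores.view(num_tokens, num_groups, -1).topk(2, dim=-1).values.sum(dim=-1)
    )
    group_idx = group_scores.topk(group_topk, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    # Expand the group mask to an expert mask and suppress out-of-group experts.
    expert_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_groups, num_experts // num_groups)
        .reshape(num_tokens, num_experts)
    )
    masked_scores = scores.masked_fill(expert_mask == 0, float('-inf'))
    return masked_scores.topk(topk, dim=-1)  # (values, expert indices)

scores = torch.rand(4, 16)  # 16 experts in 8 groups of 2
values, experts = group_limited_topk(scores, num_groups=8, group_topk=4, topk=2)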
@@ -322,10 +388,13 @@ class TransformerConfig(ModelParallelConfig):
    moe_layer_recompute: bool = False
    """Memory optimization: checkpointing moe_layer to save actiavtion memory."""

+    moe_permute_fusion: bool = False
+    """Fuse token rearrangement ops during token dispatching."""
+
    ##################
    # Context Parallel
    ##################
-    cp_comm_type: Union[str, List[str]] = None
+    cp_comm_type: Optional[Union[str, List[str]]] = None
    """Inter-gpu communication type for context parallelism.
    str: all layers share same communication type.
    List[str]: each layer has its separate communication type.
@@ -341,6 +410,30 @@ class TransformerConfig(ModelParallelConfig):
    and P2P communications in high-level CP groups (e.g., via IBLink).
    """

+    ##################
+    # Cuda Graphs
+    ##################
+    enable_cuda_graph: bool = False
+    """When set to true, TransformerLayer layers are swapped with a CUDA graphed version."""
+
+    cuda_graph_use_single_mempool: bool = False
+    """When set to true, cudagraphs will be captured inside a single mempool, in which all
+    cudagraphs may only be used once per step. If false, cudagraphs may be reused across
+    microbatches. Enabling may reduce cudagraph memory overheads due to memory fragmentation,
+    however may greatly increase the number of cudagraphs created when the number of microbatches
+    is high."""
+
+    cuda_graph_retain_backward_graph: bool = False
+    """When set to true, cudagraph backward passes will be graph captured with 'retain_grad=True'
+    This may enable cudagraphs for certain modules that are not completely cudagraph safe. For
+    more details, see: https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html."""
+
+    cuda_graph_warmup_steps: int = 3
+    """Number of warmup steps for CUDA graphs"""
+
+    external_cuda_graph: bool = False
+    """When set to true, TransformerLayer layers are swapped with user provided CUDA graphs."""
+
    ####################
    # miscellaneous
    ####################
@@ -351,18 +444,21 @@ class TransformerConfig(ModelParallelConfig):
    disable_parameter_transpose_cache: bool = False
    """When set to true, the parameter transposes are not cached for subsequent iterations."""

-    enable_cuda_graph: bool = False
-    """When set to true, TransformerLayer layers are swapped with a CUDA graphed version."""
-
-    external_cuda_graph: bool = False
-    """When set to true, TransformerLayer layers are swapped with user provided CUDA graphs."""
-
    config_logger_dir: str = ""
    """When non-empty, dumps entry-point configs to config_logger_dir"""

    flash_decode: bool = False
    """ Use the optimized flash decoding kernel during inference. """

+    use_te_rng_tracker: bool = False
+    """ Whether to use the TE or MCore version of the RNG tracker. """
+
+    inference_rng_tracker: bool = False
+    """ Whether we should instantiate a separate RNG tracker for inference. """
+
+    use_custom_fsdp: bool = False
+    """ Whether to use custom fsdp for training. """
+
    def __post_init__(self):
        """Python dataclass method that is used to modify attributes after initialization.
        See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
@@ -407,6 +503,16 @@ class TransformerConfig(ModelParallelConfig):
        if self.moe_ffn_hidden_size is None:
            self.moe_ffn_hidden_size = self.ffn_hidden_size

+        if self.moe_enable_deepep:
+            if self.moe_token_dispatcher_type != "flex":
+                raise ValueError("DeepEP backend is only supported with flex token dispatcher.")
+
+        if self.moe_token_dispatcher_type == "flex":
+            if self.moe_pad_expert_input_to_capacity:
+                raise ValueError(
+                    "Flex token dispatcher does not support moe_pad_expert_input_to_capacity"
+                )
+
        if self.moe_shared_expert_intermediate_size is not None:
            if self.moe_shared_expert_intermediate_size <= 0:
                raise ValueError(
@@ -422,13 +528,9 @@ class TransformerConfig(ModelParallelConfig):
                )

        if self.moe_expert_capacity_factor is not None:
-            if self.moe_token_dispatcher_type not in ["alltoall", "alltoall_seq"]:
-                raise ValueError(
-                    'moe_expert_capacity_factor only works with alltoall token dispatcher'
-                )
            if self.moe_expert_capacity_factor < 0:
                self.moe_expert_capacity_factor = None
-            if self.moe_router_load_balancing_type not in ["aux_loss", "none"]:
+            if self.moe_router_load_balancing_type not in ["aux_loss", "seq_aux_loss", "none"]:
                raise ValueError(
                    'moe_expert_capacity_factor only works with aux_loss or none load balancing'
                )
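For context on the capacity-factor check above: when a factor is set, each expert is given a bounded number of token slots and the rest are dropped or padded. A hedged sketch of the usual arithmetic, with made-up numbers; `expert_capacity` is not the exact Megatron helper:

import math

def expert_capacity(num_tokens, topk, num_experts, capacity_factor):
    """Approximate per-expert slot count when token dropping/padding is enabled."""
    return math.ceil(num_tokens * topk / num_experts * capacity_factor)

# e.g. 4096 tokens, top-2 routing, 8 experts, factor 1.25 -> 1280 slots per expert
print(expert_capacity(4096, 2, 8, 1.25))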
@@ -495,11 +597,125 @@ class TransformerConfig(ModelParallelConfig):
                f'false when sequence parallel is enabled: {self.sequence_parallel}'
            )

-        if self.virtual_pipeline_model_parallel_size is not None:
-            if not self.num_layers % self.virtual_pipeline_model_parallel_size == 0:
-                raise ValueError(
-                    f'num_layers: {self.num_layers} must be divisible by '
-                    f'virtual_model_parallel_size {self.virtual_pipeline_model_parallel_size}'
-                )
+        if (
+            self.num_layers_in_first_pipeline_stage is not None
+            or self.num_layers_in_last_pipeline_stage is not None
+        ) and (
+            self.account_for_embedding_in_pipeline_split or self.account_for_loss_in_pipeline_split
+        ):
+            raise ValueError(
+                'num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage cannot be'
+                'set at the same time with account_for_embedding_in_pipeline_split'
+                'and account_for_loss_in_pipeline_split'
+            )
+
+        if (
+            self.num_layers_in_first_pipeline_stage is not None
+            or self.num_layers_in_last_pipeline_stage is not None
+        ):
+            pipeline_parallel_size = self.pipeline_model_parallel_size
+            num_layers = self.num_layers
+
+            if self.num_layers_in_first_pipeline_stage is not None:
+                if self.num_layers_in_first_pipeline_stage <= 0:
+                    raise ValueError('num_layers_in_first_pipeline_stage must be larger than 0')
+
+                if self.virtual_pipeline_model_parallel_size is not None:
+                    if (
+                        self.num_layers_in_first_pipeline_stage
+                        % self.virtual_pipeline_model_parallel_size
+                        != 0
+                    ):
+                        raise ValueError(
+                            f'number of layers at first stage: '
+                            f'{self.num_layers_in_first_pipeline_stage}'
+                            f'must be divisible by virtual pipeline'
+                            f'parallel degree {self.virtual_pipeline_model_parallel_size}'
+                        )
+
+                num_layers -= self.num_layers_in_first_pipeline_stage
+                pipeline_parallel_size -= 1
+
+            if self.num_layers_in_last_pipeline_stage is not None:
+                if self.num_layers_in_last_pipeline_stage <= 0:
+                    raise ValueError('num_layers_in_last_pipeline_stage must be larger than 0')
+
+                if self.virtual_pipeline_model_parallel_size is not None:
+                    if (
+                        self.num_layers_in_last_pipeline_stage
+                        % self.virtual_pipeline_model_parallel_size
+                        != 0
+                    ):
+                        raise ValueError(
+                            f'number of layers at last stage: '
+                            f'{self.num_layers_in_last_pipeline_stage}'
+                            f'must be divisible by virtual pipeline'
+                            f'parallel degree {self.virtual_pipeline_model_parallel_size}'
+                        )
+
+                num_layers -= self.num_layers_in_last_pipeline_stage
+                pipeline_parallel_size -= 1
+
+            if not num_layers % pipeline_parallel_size == 0:
+                raise ValueError(
+                    f'number of layers at middle stage: {num_layers} must be divisible by'
+                    f'the middle pipeline model parallel size {pipeline_parallel_size}'
+                )
+
+            if self.virtual_pipeline_model_parallel_size is not None:
+                num_layers_per_middle_pipeline_rank = num_layers // pipeline_parallel_size
+                if (
+                    not num_layers_per_middle_pipeline_rank
+                    % self.virtual_pipeline_model_parallel_size
+                    == 0
+                ):
+                    raise ValueError(
+                        f'number of layers on each middle pipeline rank:'
+                        f'{num_layers_per_middle_pipeline_rank} must be divisible by virtual'
+                        f'pipeline parallel degree {self.virtual_pipeline_model_parallel_size}'
+                    )
+
+        if self.account_for_embedding_in_pipeline_split or self.account_for_loss_in_pipeline_split:
+            if self.virtual_pipeline_model_parallel_size is None:
+                pipeline_parallel_size = self.pipeline_model_parallel_size
+                if self.account_for_embedding_in_pipeline_split:
+                    pipeline_parallel_size -= 1
+
+                if self.account_for_loss_in_pipeline_split:
+                    pipeline_parallel_size -= 1
+
+                if not self.num_layers % pipeline_parallel_size == 0:
+                    raise ValueError(
+                        f'number of middle layers: {self.num_layers} must be divisible by '
+                        f'middle pipeline_model_parallel_size {pipeline_parallel_size}'
+                    )
+            else:
+                num_layers = self.num_layers
+                if self.account_for_embedding_in_pipeline_split:
+                    num_layers += 1
+
+                if self.account_for_loss_in_pipeline_split:
+                    num_layers += 1
+
+                if not num_layers % self.pipeline_model_parallel_size == 0:
+                    raise ValueError(
+                        f'num_layers: {num_layers} after enable'
+                        f'account_for_embedding_in_pipeline_split or '
+                        f'account_for_loss_in_pipeline_split must be divisible'
+                        f'by pipeline_model_parallel_size '
+                        f'{self.pipeline_model_parallel_size}'
+                    )
+
+                num_layers_per_pipeline_rank = num_layers // self.pipeline_model_parallel_size
+                if (
+                    not num_layers_per_pipeline_rank % self.virtual_pipeline_model_parallel_size
+                    == 0
+                ):
+                    raise ValueError(
+                        f'number of layers on each pipeline rank: {num_layers_per_pipeline_rank}'
+                        f'(after enable account_for_embedding_in_pipeline_split or '
+                        f'account_for_loss_in_pipeline_split) must be divisible by'
+                        f'virtual_pipeline_model_parallel_size'
+                        f'{self.virtual_pipeline_model_parallel_size}'
+                    )
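The validation added above treats the embedding and loss layers as pseudo transformer layers when they participate in the pipeline split. A made-up numeric illustration of the divisibility this implies (mirroring the +1/+1 accounting also used in get_num_layers_to_build; not Megatron code):

# Made-up numbers illustrating the embedding/loss accounting checked above.
num_layers, pp = 30, 8
effective = num_layers + 1 + 1           # embedding + loss treated as layers
assert effective % pp == 0
per_stage = effective // pp              # 4 slots per stage
first_stage_real_layers = per_stage - 1  # embedding occupies one slot -> 3
last_stage_real_layers = per_stage - 1   # loss occupies one slot -> 3
print(per_stage, first_stage_real_layers, last_stage_real_layers)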
        if self.apply_query_key_layer_scaling:
@@ -529,13 +745,19 @@ class TransformerConfig(ModelParallelConfig):
            if self.rotary_interleaved:
                raise ValueError("rotary_interleaved does not work with apply_rope_fusion.")

-            from megatron.core.models.common.embeddings.rope_utils import HAVE_APPLY_ROPE_FUSION
+            from megatron.core.models.common.embeddings.rope_utils import (
+                fused_apply_rotary_pos_emb,
+                fused_apply_rotary_pos_emb_thd,
+            )

-            if not HAVE_APPLY_ROPE_FUSION:
+            if fused_apply_rotary_pos_emb is None and fused_apply_rotary_pos_emb_thd is None:
                raise ValueError(
                    "apply_rope_fusion is not available. Please install TE >= 1.4 or Apex."
                )

+            if self.multi_latent_attention:
+                raise ValueError("multi_latent_attention does not support apply_rope_fusion.")
+
        if self.multi_latent_attention and self.rotary_interleaved:
            raise ValueError("rotary_interleaved does not work with multi_latent_attention.")
@@ -555,6 +777,12 @@ class TransformerConfig(ModelParallelConfig):
                    "alltoall_seq dispatcher not support different TP size for MoE and Dense layer."
                )

+        if self.moe_router_enable_expert_bias and self.moe_router_score_function != "sigmoid":
+            raise ValueError(
+                "Expert bias for aux-loss-free routing only supports sigmoid score function."
+                "Please set --moe-router-score-function sigmoid for sigmoid score function."
+            )
+
        if self.num_moe_experts and self.fp8:
            # TE version below 1.7.0 will raise Error when handle zeros tokens for expert
            if not is_te_min_version("1.7.0.dev0"):
@@ -569,9 +797,52 @@ class TransformerConfig(ModelParallelConfig):
                    f"but your version is {get_te_version()}."
                )

+        if (
+            self.moe_router_topk == 1
+            and self.moe_router_score_function == 'softmax'
+            and not self.moe_router_pre_softmax
+            and self.moe_router_load_balancing_type != 'sinkhorn'
+        ):
+            # Requires applying softmax before selecting the top-k when k is 1,
+            # since softmax on a [num_tokens, 1] would yield a zero gradient.
+            raise ValueError("Please use --moe-router-pre-softmax when topk is 1.")
+
+        if self.moe_router_group_topk:
+            if self.moe_router_topk_limited_devices:
+                raise ValueError(
+                    "moe_router_topk_limited_devices is deprecated and replaced by "
+                    "moe_router_group_topk and moe_router_num_groups."
+                )
+            if not self.moe_router_num_groups:
+                raise ValueError(
+                    "When using group limited routing, moe_router_num_groups must be specified."
+                )
+            else:
+                assert self.num_moe_experts % self.moe_router_num_groups == 0, (
+                    f"num_moe_experts ({self.num_moe_experts}) should be divisible by "
+                    f"moe_router_num_groups ({self.moe_router_num_groups})."
+                )
+                assert self.moe_router_group_topk <= self.moe_router_num_groups, (
+                    f"moe_router_group_topk ({self.moe_router_group_topk}) should be smaller than "
+                    f"moe_router_num_groups ({self.moe_router_num_groups})."
+                )
+        elif self.moe_router_topk_limited_devices:
+            warnings.warn(
+                "moe_router_topk_limited_devices is deprecated. Use moe_router_group_topk and "
+                "moe_router_num_groups instead."
+            )
+            self.moe_router_group_topk = self.moe_router_topk_limited_devices
+            self.moe_router_num_groups = self.expert_model_parallel_size
+
        if self.flash_decode and self.fp8:
            raise ValueError("FP8 inference is currently not support with flash decoding.")

+        if self.enable_cuda_graph:
+            if self.cpu_offloading:
+                raise ValueError("CUDA graphs not supported with CPU offloading.")
+            if self.recompute_granularity:
+                raise ValueError("CUDA graphs not supported with activation recomputation.")
+
        if self.moe_token_dispatcher_type in ['allgather', 'alltoall_seq']:
            if self.variable_seq_lengths is True:
                raise ValueError(
@@ -579,6 +850,20 @@ class TransformerConfig(ModelParallelConfig):
                    f"variable sequence length, please use alltoall dispatcher instead."
                )

+        if self.moe_permute_fusion:
+            from megatron.core.transformer.moe.moe_utils import (
+                fused_permute,
+                fused_sort_chunks_by_index,
+                fused_unpermute,
+            )
+
+            if (
+                fused_permute is None
+                or fused_sort_chunks_by_index is None
+                or fused_unpermute is None
+            ):
+                raise ValueError("fused permutation is not available. Please install TE >= 2.1.0.")
+
        if self.cp_comm_type is not None:
            if isinstance(self.cp_comm_type, list):
                assert len(self.cp_comm_type) == self.num_layers, (
@@ -590,6 +875,11 @@ class TransformerConfig(ModelParallelConfig):
                    self.cp_comm_type, str
                ), "Unsupported communication type for context parallelism!"

+        assert (
+            self.pipeline_model_parallel_size > 0
+        ), f"Pipeline model parallel size must be larger than 0 \
+            when enable --standalone-embedding-stage and --standalone-loss-stage"
+

@dataclass
class MLATransformerConfig(TransformerConfig):
@@ -617,26 +907,32 @@ class MLATransformerConfig(TransformerConfig):
    v_head_dim: int = 128
    """Dimension of the head in the V projection."""

+    normalization: str = "RMSNorm"
+    """Default normalization layer for MLA models is RMSNorm."""
+
+    rope_type: str = "yarn"
+    """Type of RoPE to use. Default to yarn, options are rope and yarn."""
+
    rotary_base: float = 10000
-    """Rotary base for the rotary embeddings."""
+    """Rotary base for the rotary embeddings, used by rope and yarn."""

-    rotary_scaling_factor: float = 40
-    """Rotary scaling factor for the rotary embeddings."""
+    rotary_percent: float = 1.0
+    """Rotary percent for the rotary embeddings, used by rope."""

-    normalization: str = "RMSNorm"
-    """Default normalization layer for MLA models is RMSNorm."""
+    rotary_scaling_factor: float = 40
+    """Rotary scaling factor for the rotary embeddings, used by yarn."""

-    max_position_embeddings: int = 163840
-    """Maximum position embeddings for the original model."""
+    max_position_embeddings: int = 4096
+    """Maximum position embeddings for the original model, used by yarn."""

    beta_fast: float = 32
-    """Beta fast for YaRN RoPE."""
+    """Beta fast for YaRN RoPE, used by yarn."""

    beta_slow: float = 1
-    """Beta slow for YaRN RoPE."""
+    """Beta slow for YaRN RoPE, used by yarn."""

    mscale: float = 0.707
-    """Mscale for YaRN RoPE in Multi-Latent Attention."""
+    """Mscale for YaRN RoPE in Multi-Latent Attention, used by yarn."""

    mscale_all_dim: float = 0.707
-    """Mscale all dimensions for YaRN RoPE in Multi-Latent Attention."""
+    """Mscale all dimensions for YaRN RoPE in Multi-Latent Attention, used by yarn."""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+import warnings
from abc import ABC
from dataclasses import dataclass, field
from typing import Dict, Optional, Union

import torch
+import torch.distributed

from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
@@ -17,6 +18,159 @@ from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor


+def get_transformer_layer_offset(config: TransformerConfig):
+    """Get the index offset of current pipeline stage, given the level of pipelining."""
+    pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
+    if not parallel_state.is_inside_encoder():
+        pp_decoder_start = parallel_state.get_pipeline_model_parallel_decoder_start()
+        if pp_decoder_start is not None:
+            pipeline_rank = pipeline_rank - pp_decoder_start
+
+    if config.pipeline_model_parallel_size > 1:
+        if (
+            config.num_layers_in_first_pipeline_stage is not None
+            or config.num_layers_in_last_pipeline_stage is not None
+        ):
+            # Calculate number of pipeline stages to distribute the remaining Transformer
+            # layers after deducting the Transformer layers in the first or the last stages
+            middle_pipeline_stages = config.pipeline_model_parallel_size
+            middle_pipeline_stages -= sum(
+                [
+                    1 if x is not None else 0
+                    for x in (
+                        config.num_layers_in_first_pipeline_stage,
+                        config.num_layers_in_last_pipeline_stage,
+                    )
+                ]
+            )
+
+            # Calculate layers to distribute in each pipeline stage. If the
+            # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage
+            # are not set, we will not enable uneven pipeline. All layers will be treated
+            # as middle layers.
+            num_layers_in_first_pipeline_stage = (
+                0
+                if config.num_layers_in_first_pipeline_stage is None
+                else config.num_layers_in_first_pipeline_stage
+            )
+            num_layers_in_last_pipeline_stage = (
+                0
+                if config.num_layers_in_last_pipeline_stage is None
+                else config.num_layers_in_last_pipeline_stage
+            )
+
+            middle_num_layers = (
+                config.num_layers
+                - num_layers_in_first_pipeline_stage
+                - num_layers_in_last_pipeline_stage
+            )
+
+            if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+                vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
+                vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+
+                # Calculate number of layers in each virtual model chunk
+                # If the num_layers_in_first_pipeline_stage and
+                # num_layers_in_last_pipeline_stage are not set, all pipeline stages
+                # will be treated as middle pipeline stages in the calculation
+                num_layers_per_virtual_model_chunk_in_first_pipeline_stage = (
+                    0
+                    if config.num_layers_in_first_pipeline_stage is None
+                    else config.num_layers_in_first_pipeline_stage // vp_size
+                )
+
+                num_layers_per_virtual_model_chunk_in_last_pipeline_stage = (
+                    0
+                    if config.num_layers_in_last_pipeline_stage is None
+                    else config.num_layers_in_last_pipeline_stage // vp_size
+                )
+
+                num_layers_per_vritual_model_chunk_in_middle_pipeline_stage = (
+                    middle_num_layers // vp_size
+                )
+
+                # First stage + middle stage + last stage
+                total_virtual_chunks = (
+                    num_layers_per_virtual_model_chunk_in_first_pipeline_stage
+                    + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage
+                    + num_layers_per_virtual_model_chunk_in_last_pipeline_stage
+                )
+
+                # Calculate the layer offset with interleaved uneven pipeline parallelism
+                if pipeline_rank == 0:
+                    offset = vp_rank * total_virtual_chunks
+                else:
+                    offset = (
+                        vp_rank * total_virtual_chunks
+                        + num_layers_per_virtual_model_chunk_in_first_pipeline_stage
+                        + (pipeline_rank - 1)
+                        * (
+                            num_layers_per_vritual_model_chunk_in_middle_pipeline_stage
+                            // middle_pipeline_stages
+                        )
+                    )
+            else:
+                if middle_pipeline_stages > 0:
+                    num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages
+                else:
+                    num_layers_per_pipeline_rank = 0
+
+                middle_pipeline_rank = (
+                    pipeline_rank
+                    if config.num_layers_in_first_pipeline_stage is None
+                    else pipeline_rank - 1
+                )
+
+                if pipeline_rank == 0:
+                    offset = 0
+                else:
+                    offset = (
+                        middle_pipeline_rank * num_layers_per_pipeline_rank
+                    ) + num_layers_in_first_pipeline_stage
+        else:
+            num_layers = config.num_layers
+
+            # Increase the number of layers by one if we include the embedding (loss)
+            # layer into pipeline parallelism partition and placement
+            if config.account_for_embedding_in_pipeline_split:
+                num_layers += 1
+
+            if config.account_for_loss_in_pipeline_split:
+                num_layers += 1
+
+            num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size
+
+            if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+                vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
+                vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+
+                num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
+                total_virtual_chunks = num_layers // vp_size
+                offset = vp_rank * total_virtual_chunks + (
+                    pipeline_rank * num_layers_per_virtual_rank
+                )
+
+                # Reduce the offset of embedding layer from the total layer number
+                if (
+                    config.account_for_embedding_in_pipeline_split
+                    and not parallel_state.is_pipeline_first_stage()
+                ):
+                    offset -= 1
+            else:
+                offset = pipeline_rank * num_layers_per_pipeline_rank
+
+                # Reduce the offset of embedding layer from the total layer number
+                if (
+                    config.account_for_embedding_in_pipeline_split
+                    and not parallel_state.is_pipeline_first_stage()
+                ):
+                    offset -= 1
+    else:
+        offset = 0
+    return offset
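get_transformer_layer_offset above generalizes the old TransformerLayer._get_layer_offset to explicit first/last stage sizes and to the embedding/loss accounting. A standalone sketch of the non-interleaved uneven case, with made-up numbers; `uneven_layer_offset` is not a Megatron function, it only reproduces the offset arithmetic of that branch:

def uneven_layer_offset(pipeline_rank, num_layers, pp_size, first, last):
    """Global index of the first transformer layer owned by `pipeline_rank`
    (non-interleaved schedule, explicit first/last stage sizes)."""
    middle_stages = pp_size - (first is not None) - (last is not None)
    first = first or 0
    last = last or 0
    middle_per_rank = (num_layers - first - last) // max(middle_stages, 1)
    if pipeline_rank == 0:
        return 0
    middle_rank = pipeline_rank - 1 if first else pipeline_rank
    return first + middle_rank * middle_per_rank

# num_layers=30, pp=4, first=6, last=6 -> stages start at layers 0, 6, 15, 24
for rank in range(4):
    print(rank, uneven_layer_offset(rank, 30, 4, first=6, last=6))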
@dataclass @dataclass
class TransformerLayerSubmodules: class TransformerLayerSubmodules:
""" """
...@@ -93,14 +247,16 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): ...@@ -93,14 +247,16 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
): ):
super().__init__(config=config) super().__init__(config=config)
if config.enable_cuda_graph and self.training: if config.enable_cuda_graph:
assert ( if not self.training:
not config.cpu_offloading and config.recompute_granularity is None # Cudagraphs for inference are only enabled with the flash decoding kernel
), "Cudagraphs not supported" assert (
self.cudagraph_manager = CudaGraphManager() self.config.flash_decode
), "--flash-decode is required to use CUDA graphs during inference"
self.cudagraph_manager = CudaGraphManager(config)
self.submodules_config = submodules self.submodules_config = submodules
self.layer_number = layer_number + TransformerLayer._get_layer_offset(self.config) self.layer_number = layer_number + get_transformer_layer_offset(self.config)
self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout
# [Module 1: Input Layernorm] Optional Layernorm on the input data # [Module 1: Input Layernorm] Optional Layernorm on the input data
...@@ -174,82 +330,17 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): ...@@ -174,82 +330,17 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
@staticmethod @staticmethod
def _get_layer_offset(config: TransformerConfig): def _get_layer_offset(config: TransformerConfig):
"""Get the index offset of current pipeline stage, given the level of pipelining.""" """
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() Get the layer offset for the current pipeline stage.
if not parallel_state.is_inside_encoder():
pp_decoder_start = parallel_state.get_pipeline_model_parallel_decoder_start()
if pp_decoder_start is not None:
pipeline_rank = pipeline_rank - pp_decoder_start
num_layers_per_pipeline_rank = config.num_layers // config.pipeline_model_parallel_size
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
total_num_layers = config.num_layers
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
total_virtual_chunks = total_num_layers // vp_size
offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)
else:
# Each stage gets a contiguous set of layers.
if config.pipeline_model_parallel_size > 1:
if (
config.first_pipeline_num_layers is not None
or config.last_pipeline_num_layers is not None
):
# Calculate number of pipelines for distributing layers
middle_pipeline_stages = config.pipeline_model_parallel_size
middle_pipeline_stages -= sum(
[
1 if x is not None else 0
for x in (
config.first_pipeline_num_layers,
config.last_pipeline_num_layers,
)
]
)
# Calculate layers to distribute
first_pipeline_offset = (
0
if config.first_pipeline_num_layers is None
else config.first_pipeline_num_layers
)
last_pipeline_offset = (
0
if config.last_pipeline_num_layers is None
else config.last_pipeline_num_layers
)
middle_num_layers = (
config.num_layers - first_pipeline_offset - last_pipeline_offset
)
if middle_pipeline_stages > 0:
num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages
else:
num_layers_per_pipeline_rank = 0
middle_pipeline_rank = (
pipeline_rank
if config.first_pipeline_num_layers is None
else pipeline_rank - 1
)
if pipeline_rank == 0:
offset = 0
else:
offset = (
middle_pipeline_rank * num_layers_per_pipeline_rank
) + first_pipeline_offset
else:
offset = pipeline_rank * num_layers_per_pipeline_rank
else:
offset = 0
return offset
"""
Get the layer offset for the current pipeline stage.

Deprecated: please use `get_transformer_layer_offset` instead.
"""
warnings.warn(
"TransformerLayer._get_layer_offset is deprecated."
"Please use get_transformer_layer_offset instead."
)
return get_transformer_layer_offset(config)
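For reference, the interleaved (virtual pipeline) branch of the removed implementation computes the offset as `vp_rank * (num_layers // vp_size) + pipeline_rank * (num_layers_per_pipeline_rank // vp_size)`. The sketch below replays that arithmetic outside Megatron with assumed values (16 layers, pipeline parallel size 4, virtual pipeline size 2) and does not touch `parallel_state`:

```python
# Minimal sketch of the interleaved layer-offset arithmetic; all values are
# assumed for illustration only.
num_layers, pp_size, vp_size = 16, 4, 2

layers_per_pp_rank = num_layers // pp_size               # 4 layers per pipeline rank
layers_per_virtual_rank = layers_per_pp_rank // vp_size  # 2 layers per virtual chunk
total_virtual_chunks = num_layers // vp_size             # 8

for pp_rank in range(pp_size):
    chunks = []
    for vp_rank in range(vp_size):
        offset = vp_rank * total_virtual_chunks + pp_rank * layers_per_virtual_rank
        chunks.append(list(range(offset, offset + layers_per_virtual_rank)))
    print(f"pp_rank={pp_rank}: layer chunks {chunks}")
# pp_rank=0: layer chunks [[0, 1], [8, 9]]
# pp_rank=1: layer chunks [[2, 3], [10, 11]]
# pp_rank=2: layer chunks [[4, 5], [12, 13]]
# pp_rank=3: layer chunks [[6, 7], [14, 15]]
```

Each pipeline rank therefore owns two non-contiguous layer chunks, which is what the interleaved schedule expects.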
def forward( def forward(
self, self,
...@@ -263,6 +354,7 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): ...@@ -263,6 +354,7 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
attention_bias=None, attention_bias=None,
inference_params=None, inference_params=None,
packed_seq_params=None, packed_seq_params=None,
sequence_len_offset=None,
): ):
""" """
Perform a forward pass through the transformer layer. Perform a forward pass through the transformer layer.
...@@ -304,6 +396,7 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): ...@@ -304,6 +396,7 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
rotary_pos_sin=rotary_pos_sin, rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias, attention_bias=attention_bias,
packed_seq_params=packed_seq_params, packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
) )
# TODO: could we move `bias_dropout_add_exec_handler` itself # TODO: could we move `bias_dropout_add_exec_handler` itself
...@@ -392,6 +485,18 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): ...@@ -392,6 +485,18 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
return sharded_state_dict return sharded_state_dict
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
if hasattr(self, 'cudagraph_manager'):
# Training and validation mode CUDA graphs
if hasattr(self, 'cudagraph_manager') and kwargs.get('inference_params') is None:
return self.cudagraph_manager(self, args, kwargs)
# Inference mode. CUDA graphs are used in the decode phase only, when attn mask is None
elif (
not self.training
and hasattr(self, 'cudagraph_manager')
and kwargs.get('inference_params') is not None
and kwargs['inference_params'].decode_mode
):
assert (
kwargs.get('attention_mask') is None
), f"Attention mask must not be set when using CUDA graphs for decode"
return self.cudagraph_manager(self, args, kwargs) return self.cudagraph_manager(self, args, kwargs)
return super(MegatronModule, self).__call__(*args, **kwargs) return super(MegatronModule, self).__call__(*args, **kwargs)
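The dispatch above can be summarized as: with a CUDA graph manager attached, training/validation calls always route through it, while inference calls only do so in decode mode, and the attention mask must then be None. A condensed, illustrative restatement (the class below is a stand-in for the real `inference_params` object, not Megatron code):

```python
from dataclasses import dataclass

@dataclass
class FakeInferenceParams:  # stand-in for the real inference_params object
    decode_mode: bool

def uses_cuda_graph(has_manager, training, inference_params, attention_mask):
    """Illustrative restatement of the branch conditions in __call__ above."""
    if not has_manager:
        return False
    if inference_params is None:
        return True  # training / validation step
    # Inference: only the decode phase; the real code additionally asserts that
    # the attention mask is None rather than silently falling through.
    return (not training) and inference_params.decode_mode and attention_mask is None

print(uses_cuda_graph(True, False, FakeInferenceParams(decode_mode=True), None))   # True
print(uses_cuda_graph(True, False, FakeInferenceParams(decode_mode=False), None))  # False (prefill)
```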
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Utilities for transformer layers.""" """Utilities for transformer layers."""
from functools import lru_cache from functools import lru_cache
from operator import itemgetter from operator import itemgetter
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple, Union from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch import torch
...@@ -32,6 +32,7 @@ def get_default_causal_mask(sq: int) -> torch.Tensor: ...@@ -32,6 +32,7 @@ def get_default_causal_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
# pylint: disable=missing-function-docstring
def attention_mask_func(attention_scores, attention_mask): def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0) attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores return attention_scores
...@@ -43,11 +44,14 @@ def gelu_impl(x): ...@@ -43,11 +44,14 @@ def gelu_impl(x):
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
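The magic constant `0.7978845608028654` in `gelu_impl` is sqrt(2/pi), so the fused expression is the standard tanh approximation of GELU, 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))). A quick self-contained check:

```python
import math

# Verify the constant and evaluate the tanh approximation at x = 1.0; the exact
# GELU value there is about 0.8413, so the approximation is close.
assert abs(0.7978845608028654 - math.sqrt(2.0 / math.pi)) < 1e-15

def gelu_tanh(x: float) -> float:
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

print(gelu_tanh(1.0))  # ~0.8412
```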
# pylint: disable=missing-function-docstring
def openai_gelu(x): def openai_gelu(x):
return gelu_impl(x) return gelu_impl(x)
# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter # This is actually Python equivalent of torch.nn.functional.gelu(), also with
# type hints for ONNX exporter
# pylint: disable=missing-function-docstring
@jit_fuser @jit_fuser
def erf_gelu(x): def erf_gelu(x):
return ( return (
...@@ -125,6 +129,9 @@ def make_sharded_object_for_checkpoint( ...@@ -125,6 +129,9 @@ def make_sharded_object_for_checkpoint(
ShardedObject ShardedObject
replica_id (Union[None, int, Tuple[int, ...]]): replica id replica_id (Union[None, int, Tuple[int, ...]]): replica id
""" """
is_obj_fully_sharded = hasattr(obj, 'fully_shard_param_local_index')
assert not is_obj_fully_sharded, f"Fully sharded object not supported: {key}"
if replica_id is None: if replica_id is None:
replica_id = ( replica_id = (
0, 0,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
"""Utility functions used throughout Megatron core""" """Utility functions used throughout Megatron core"""
import array import array
import functools
import hashlib import hashlib
import logging import logging
import math import math
...@@ -24,6 +25,7 @@ from packaging.version import Version as PkgVersion ...@@ -24,6 +25,7 @@ from packaging.version import Version as PkgVersion
try: try:
from torch.distributed._tensor import DTensor from torch.distributed._tensor import DTensor
from torch.distributed.tensor.placement_types import Shard
HAVE_DTENSOR = True HAVE_DTENSOR = True
except ImportError: except ImportError:
...@@ -268,21 +270,14 @@ def safely_set_viewless_tensor_data(tensor, new_data_tensor): ...@@ -268,21 +270,14 @@ def safely_set_viewless_tensor_data(tensor, new_data_tensor):
def init_method_normal(sigma): def init_method_normal(sigma):
"""Init method based on N(0, sigma).""" """Init method based on N(0, sigma)."""
return functools.partial(torch.nn.init.normal_, mean=0.0, std=sigma)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers): def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers).""" """Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers) std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
return functools.partial(torch.nn.init.normal_, mean=0.0, std=std)
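The refactor above swaps the nested closures for `functools.partial`; both forms return a callable that initializes a tensor in place with the requested standard deviation. A small self-contained check with assumed values:

```python
import functools
import math
import torch

def scaled_init_method_normal(sigma, num_layers):
    # mirrors the refactored helper above
    std = sigma / math.sqrt(2.0 * num_layers)
    return functools.partial(torch.nn.init.normal_, mean=0.0, std=std)

init_fn = scaled_init_method_normal(sigma=0.02, num_layers=24)  # std is about 0.00289
weight = torch.empty(4, 4)
init_fn(weight)          # fills in place, exactly like the old closure did
print(weight.std())      # sample std, roughly 0.003 for this tiny tensor
```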
def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any): def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any):
...@@ -369,16 +364,13 @@ def check_param_hashes_across_dp_replicas( ...@@ -369,16 +364,13 @@ def check_param_hashes_across_dp_replicas(
for params, local_param_hashes, all_gather_group in zip( for params, local_param_hashes, all_gather_group in zip(
[non_expert_params, expert_params], [non_expert_params, expert_params],
[local_non_expert_param_hashes, local_expert_param_hashes], [local_non_expert_param_hashes, local_expert_param_hashes],
[
parallel_state.get_data_parallel_group_gloo(),
parallel_state.get_expert_data_parallel_group_gloo(),
],
[parallel_state.get_data_parallel_group(), parallel_state.get_expert_data_parallel_group()],
): ):
# Collect per-parameter hashes across all ranks in group. # Collect per-parameter hashes across all ranks in group.
assert len(params) == len(local_param_hashes) assert len(params) == len(local_param_hashes)
if len(params) == 0: if len(params) == 0:
continue continue
local_param_hashes = torch.stack(local_param_hashes) local_param_hashes = torch.stack(local_param_hashes).cuda()
all_param_hashes = [ all_param_hashes = [
torch.zeros_like(local_param_hashes) torch.zeros_like(local_param_hashes)
for _ in range(torch.distributed.get_world_size(all_gather_group)) for _ in range(torch.distributed.get_world_size(all_gather_group))
...@@ -442,6 +434,28 @@ def make_tp_sharded_tensor_for_checkpoint( ...@@ -442,6 +434,28 @@ def make_tp_sharded_tensor_for_checkpoint(
if replica_id is None: if replica_id is None:
replica_id = (0, 0, dp_replica_id) replica_id = (0, 0, dp_replica_id)
if hasattr(tensor, 'fully_shard_param_local_shard'):
assert len(replica_id) == 3, f'Expected replica_id format (PP, TP, DP), got: {replica_id}'
replica_id = (*replica_id[:2], 0)
sh_ten = ShardedTensor.from_rank_offsets_flat(
key,
tensor.fully_shard_param_local_shard,
tensor.shape,
*prepend_offsets,
(
tp_axis + prepend_axis_num,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_tensor_model_parallel_world_size(),
),
flattened_range=slice(*tensor.fully_shard_param_local_index),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
**kwargs,
)
setattr(sh_ten, 'is_data_parallel_fully_shard', True)
return sh_ten
return ShardedTensor.from_rank_offsets( return ShardedTensor.from_rank_offsets(
key, key,
tensor, tensor,
...@@ -469,12 +483,29 @@ def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_ ...@@ -469,12 +483,29 @@ def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_
if HAVE_DTENSOR and isinstance(tensor, DTensor): if HAVE_DTENSOR and isinstance(tensor, DTensor):
# FSDP2 sharding # FSDP2 sharding
dp_replica_id = 0 dp_replica_id = 0
tensor = tensor._local_tensor tensor = get_full_tensor_if_necessary(tensor)
new_offsets.append((prepend_axis_num, dp_rank, dp_size)) new_offsets.append((prepend_axis_num, dp_rank, dp_size))
if replica_id is None: if replica_id is None:
replica_id = (0, parallel_state.get_tensor_model_parallel_rank(), dp_replica_id) replica_id = (0, parallel_state.get_tensor_model_parallel_rank(), dp_replica_id)
if hasattr(tensor, 'fully_shard_param_local_shard'):
assert len(replica_id) == 3, f'Expected replica_id format (PP, TP, DP), got: {replica_id}'
replica_id = (*replica_id[:2], 0)
sh_ten = ShardedTensor.from_rank_offsets_flat(
key,
tensor.fully_shard_param_local_shard,
tensor.shape,
*prepend_offsets,
flattened_range=slice(*tensor.fully_shard_param_local_index),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
**kwargs,
)
setattr(sh_ten, 'is_data_parallel_fully_shard', True)
return sh_ten
return ShardedTensor.from_rank_offsets( return ShardedTensor.from_rank_offsets(
key, key,
tensor, tensor,
...@@ -486,6 +517,22 @@ def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_ ...@@ -486,6 +517,22 @@ def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_
) )
def get_full_tensor_if_necessary(tensor):
"""For DTensor gets full tensor if some ranks will not have a local copy"""
need_full_tensor = False
for i in range(tensor.device_mesh.ndim):
if (
isinstance(tensor.placements[i], Shard)
and tensor.device_mesh.shape[i] > tensor.shape[tensor.placements[i].dim]
):
need_full_tensor = True
break
tensor = tensor.full_tensor() if need_full_tensor else tensor._local_tensor
return tensor
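`get_full_tensor_if_necessary` falls back to `full_tensor()` whenever a sharded mesh dimension has more ranks than the tensor dimension it shards has elements, since some ranks would otherwise have no local copy. The numbers below are purely illustrative and do not construct a DTensor:

```python
# Illustrative shapes only; this mirrors the decision rule above without DTensor.
mesh_shape = (8,)         # assumed 8-way mesh on the sharded dimension
tensor_shape = (6, 1024)  # a small parameter with only 6 rows
shard_dim = 0             # Shard(0)-style placement on mesh dim 0

need_full_tensor = mesh_shape[0] > tensor_shape[shard_dim]
print(need_full_tensor)   # True -> gather full_tensor(); otherwise _local_tensor is enough
```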
def to_local_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor: def to_local_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor:
"""Returns the local shard of the given tensor if it is a DTensor.""" """Returns the local shard of the given tensor if it is a DTensor."""
with torch.no_grad(): with torch.no_grad():
...@@ -1399,17 +1446,52 @@ __straggler__ = StragglerDetector() ...@@ -1399,17 +1446,52 @@ __straggler__ = StragglerDetector()
""" """
# Check if Transformer Engine has Float8Tensor class
HAVE_TE_FLOAT8TENSOR = False
try:
from transformer_engine.pytorch.float8_tensor import Float8Tensor
HAVE_TE_FLOAT8TENSOR = True
except (ImportError, ModuleNotFoundError):
# Float8Tensor not found
pass
def is_float8tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a Transformer Engine Float8Tensor"""
return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor)
def is_submodule(module, parent_module, strict=True):
"""
Check if a module is a submodule of another module.
"""
if strict:
if module is parent_module:
return False
for m in parent_module.modules():
if m is module:
return True
return False
########################
### context parallel ###
########################
def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
"""Slice batch input along sequence dimension into multiple chunks,
which are parallelized across GPUs in a context parallel group.
"""
# With causal masking, each token only attends to the tokens before it, so naively
# splitting the sequence into CP chunks causes severe load imbalance: chunks at the
# end of the sequence carry a much larger workload than the earlier ones. To address
# this, we split the sequence into 2*CP chunks. Assuming CP=2, we get 4 chunks;
# chunk_0 and chunk_3 are assigned to GPU0 and chunk_1 and chunk_2 to GPU1, so that
# the workload is balanced across the GPUs in a context parallel group.
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size > 1:
cp_rank = parallel_state.get_context_parallel_rank()
for key, val in batch.items():
if val is not None:
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor(
[cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
).cuda(non_blocking=True)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val
return batch
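The chunk pairing described in the comment can be reproduced on a single process; the example below uses an assumed CP size of 2 and an 8-token sequence, and only mimics the view/index_select logic above (no context-parallel group is created):

```python
import torch

cp_size = 2
tokens = torch.arange(8).unsqueeze(0)  # [1, 8]; values 0..7 stand for token positions

for cp_rank in range(cp_size):
    # Split into 2*CP chunks along the sequence dimension, then keep the pair
    # (cp_rank, 2*cp_size - cp_rank - 1) so early and late chunks balance out.
    val = tokens.view(1, 2 * cp_size, tokens.shape[1] // (2 * cp_size))
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    kept = val.index_select(1, index).reshape(1, -1)
    print(f"cp_rank={cp_rank} keeps positions {kept.tolist()[0]}")
# cp_rank=0 keeps positions [0, 1, 6, 7]
# cp_rank=1 keeps positions [2, 3, 4, 5]
```

This matches the comment: the rank that gets an early (cheap) chunk also gets a late (expensive) one.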
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
...@@ -6,11 +6,10 @@ from collections.abc import Iterable ...@@ -6,11 +6,10 @@ from collections.abc import Iterable
import torch import torch
from megatron.core import InferenceParams, mpu
from megatron.training import get_args from megatron.training import get_args
from megatron.core import mpu, InferenceParams
from .communication import (
send_to_next_pipeline_rank,
recv_from_prev_pipeline_rank_)
from .communication import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank
class ForwardStep: class ForwardStep:
...@@ -46,7 +45,7 @@ class ForwardStep: ...@@ -46,7 +45,7 @@ class ForwardStep:
# This runs only if current_batch_x_seqlen > args.inference_batch_times_seqlen_threshold # This runs only if current_batch_x_seqlen > args.inference_batch_times_seqlen_threshold
# and requires setting args.pipeline_model_parallel > 1. The batch will be split into # and requires setting args.pipeline_model_parallel > 1. The batch will be split into
# smaller microbatches to be pipelined through the stages. # smaller microbatches to be pipelined through the stages.
if self.pipeline_size_larger_than_one: if self.pipeline_size_larger_than_one and self.pipelining_batch_x_seqlen != -1:
seq_len = tokens.size(1) if recv_buffer_seq_length is None else recv_buffer_seq_length seq_len = tokens.size(1) if recv_buffer_seq_length is None else recv_buffer_seq_length
current_batch_x_seqlen = tokens.size(0) * seq_len current_batch_x_seqlen = tokens.size(0) * seq_len
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen: if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
......
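To read the threshold condition above with concrete (assumed) numbers: with pipeline parallelism enabled, a batch is split into pipelined micro-batches only once `batch_size * seq_len` reaches `args.inference_batch_times_seqlen_threshold`, and the new check means a threshold of -1 disables this path entirely.

```python
# Assumed values, just to make the condition concrete.
pipeline_model_parallel_size = 2
inference_batch_times_seqlen_threshold = 512   # -1 would disable micro-batch pipelining
batch_size, seq_len = 4, 256
current_batch_x_seqlen = batch_size * seq_len  # 1024

if (
    pipeline_model_parallel_size > 1
    and inference_batch_times_seqlen_threshold != -1
    and current_batch_x_seqlen >= inference_batch_times_seqlen_threshold
):
    print("split into micro-batches and pipeline them through the stages")
else:
    print("run the whole batch through the stages in one pass")
```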
...@@ -8,6 +8,7 @@ import torch.nn.functional as F ...@@ -8,6 +8,7 @@ import torch.nn.functional as F
from megatron.training import get_args, get_tokenizer from megatron.training import get_args, get_tokenizer
from megatron.core import mpu from megatron.core import mpu
from megatron.training.utils import get_ltor_masks_and_position_ids from megatron.training.utils import get_ltor_masks_and_position_ids
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from .communication import ( from .communication import (
copy_from_last_to_first_pipeline_stage, copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage, broadcast_from_last_pipeline_stage,
...@@ -86,7 +87,7 @@ def score_and_return_on_first_stage(model, tokens: torch.Tensor, lengths: torch. ...@@ -86,7 +87,7 @@ def score_and_return_on_first_stage(model, tokens: torch.Tensor, lengths: torch.
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
# Always the last stage should have an output. # Always the last stage should have an output.
assert logits is not None assert logits is not None
log_probs = F.log_softmax(logits, dim=2) log_probs = F.log_softmax(logits, dim=2).to(dtype=output_topk_log_probs.dtype)
# Pick the tokens that we need to get the log # Pick the tokens that we need to get the log
# probabilities for. Note that next input token is # probabilities for. Note that next input token is
...@@ -202,7 +203,7 @@ def generate_tokens_probs_and_return_on_first_stage( ...@@ -202,7 +203,7 @@ def generate_tokens_probs_and_return_on_first_stage(
device=torch.cuda.current_device()) device=torch.cuda.current_device())
# ============= # =============
# Run infernece # Run inference
# ============= # =============
with torch.no_grad(): with torch.no_grad():
...@@ -211,15 +212,24 @@ def generate_tokens_probs_and_return_on_first_stage( ...@@ -211,15 +212,24 @@ def generate_tokens_probs_and_return_on_first_stage(
prev_context_length = 0 prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length): for context_length in range(min_prompt_length, max_sequence_length):
prefill = context_length == min_prompt_length
if not prefill:
forward_step.inference_params.enable_decode_mode()
# Pick the slice that we need to pass through the network. # Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length] tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length] positions2use = position_ids[:, prev_context_length:context_length]
# Do not pass a variable-shape attention mask in the decode phase.
attention_mask2use = attention_mask[ attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length] ..., prev_context_length:context_length, :context_length] if prefill else None
# logits will be meaningful only in the last pipeline stage. # logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use) logits = forward_step(tokens2use, positions2use, attention_mask2use)
if args.enable_cuda_graph:
create_cudagraphs()
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
if prevent_newline_after_colon: if prevent_newline_after_colon:
logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":" logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":"
...@@ -343,7 +353,7 @@ def beam_search_and_return_on_first_stage(model, forward_step, tokens, lengths, ...@@ -343,7 +353,7 @@ def beam_search_and_return_on_first_stage(model, forward_step, tokens, lengths,
device=torch.cuda.current_device()).unsqueeze(1) device=torch.cuda.current_device()).unsqueeze(1)
scores_size_tensor, tokens_size_tensor = None, None scores_size_tensor, tokens_size_tensor = None, None
# ============= # =============
# Run infernece # Run inference
# ============= # =============
with torch.no_grad(): with torch.no_grad():
tokens = tokens.repeat(beam_size, 1) tokens = tokens.repeat(beam_size, 1)
...@@ -351,11 +361,15 @@ def beam_search_and_return_on_first_stage(model, forward_step, tokens, lengths, ...@@ -351,11 +361,15 @@ def beam_search_and_return_on_first_stage(model, forward_step, tokens, lengths,
prev_context_length = 0 prev_context_length = 0
for context_length in range(prompt_length, final_sequence_length): for context_length in range(prompt_length, final_sequence_length):
prefill = context_length == prompt_length
# Pick the slice that we need to pass through the network. # Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length] tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length] positions2use = position_ids[:, prev_context_length:context_length]
# Do not pass a variable-shape attention mask in the decode phase.
attention_mask2use = attention_mask[ attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length] ..., prev_context_length:context_length, :context_length] if not prefill else None
# logits will be meaningful only in the last pipeline stage. # logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use) logits = forward_step(tokens2use, positions2use, attention_mask2use)
......
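The new `prefill` flag and the "Do not pass a variable-shape attention mask in the decode phase" comments above capture the key constraint: in the prefill step the causal-mask slice has a fixed square shape, while in decode it would grow by one column every iteration, which is why the mask is dropped there (and why the decode CUDA-graph path in `TransformerLayer.__call__` asserts it is None). The toy loop below only prints the mask shapes involved, using an assumed prompt length of 6 and a maximum length of 10:

```python
import torch

prompt_len, max_len = 6, 10
attention_mask = torch.triu(torch.ones(1, 1, max_len, max_len), diagonal=1).bool()

prev_context_length = 0
for context_length in range(prompt_len, max_len):
    prefill = context_length == prompt_len
    mask2use = (
        attention_mask[..., prev_context_length:context_length, :context_length]
        if prefill
        else None  # decode: the slice shape would change every step, so no mask is passed
    )
    shape = tuple(mask2use.shape) if mask2use is not None else None
    print(f"context_length={context_length} prefill={prefill} mask shape={shape}")
    prev_context_length = context_length
# context_length=6 prefill=True  mask shape=(1, 1, 6, 6)
# context_length=7 prefill=False mask shape=None
# ...
```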
File mode changed from 100755 to 100644
...@@ -22,6 +22,7 @@ def detokenize_generations(tokens_gpu_tensor, ...@@ -22,6 +22,7 @@ def detokenize_generations(tokens_gpu_tensor,
tokens = tokens_gpu_tensor.cpu().numpy().tolist() tokens = tokens_gpu_tensor.cpu().numpy().tolist()
lengths = lengths_gpu_tensor.cpu().numpy().tolist() lengths = lengths_gpu_tensor.cpu().numpy().tolist()
for sequence_tokens, length in zip(tokens, lengths): for sequence_tokens, length in zip(tokens, lengths):
sequence_tokens = sequence_tokens[:length] sequence_tokens = sequence_tokens[:length]
detok_str = tokenizer.detokenize(sequence_tokens) detok_str = tokenizer.detokenize(sequence_tokens)
......