Commit 0d99ae1f authored by silencealiang

add

parent c271aaae
Pipeline #2498 canceled
@@ -576,7 +576,10 @@ class TransformerBlock(MegatronModule):
         non_homogeneous_layers = metadata is not None and metadata.get(
             'non_homogeneous_layers', False
         )
-        if self.config.num_moe_experts is not None:
+        if isinstance(self.config.moe_layer_freq, int):
+            if self.config.moe_layer_freq > 1:
+                non_homogeneous_layers = True
+        elif isinstance(self.config.moe_layer_freq, list):
             non_homogeneous_layers = True

         sharded_state_dict = {}
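For context, the replaced check keys non-homogeneity off moe_layer_freq rather than num_moe_experts: an integer frequency greater than 1, or an explicit per-layer list, means dense and MoE layers are mixed. A minimal standalone sketch of that interpretation (the helper name is hypothetical and not part of this commit):

# Minimal sketch (hypothetical helper, not part of this commit): expand moe_layer_freq
# into a per-layer MoE flag. An int N means "one MoE layer every N layers"; a list is
# an explicit per-layer 0/1 pattern. Any mix of 0s and 1s makes the block non-homogeneous.
from typing import List, Union

def build_moe_layer_pattern(moe_layer_freq: Union[int, List[int]], num_layers: int) -> List[int]:
    if isinstance(moe_layer_freq, int):
        return [1 if i % moe_layer_freq == 0 else 0 for i in range(num_layers)]
    if isinstance(moe_layer_freq, list):
        assert len(moe_layer_freq) == num_layers
        return list(moe_layer_freq)
    raise TypeError(f"unsupported moe_layer_freq: {moe_layer_freq!r}")

print(build_moe_layer_pattern(1, 4))             # [1, 1, 1, 1]  -> homogeneous
print(build_moe_layer_pattern(2, 4))             # [1, 0, 1, 0]  -> non-homogeneous
print(build_moe_layer_pattern([0, 1, 1, 1], 4))  # explicit pattern -> non-homogeneous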
@@ -266,23 +266,37 @@ class TransformerConfig(ModelParallelConfig):
     """MoE Feed-Forward Network hidden size"""

     moe_router_load_balancing_type: str = "aux_loss"
-    """Determines the load balancing strategy for the router. "aux_loss" corresponds to the load
-    balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing
-    algorithm used in S-BASE, and "none" implies no load balancing."""
+    """The load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss
+    used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the loss used in DeepSeekV2,
+    which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing
+    algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss"."""

     moe_router_topk: int = 2
     """Number of experts to route to for each token."""

+    moe_router_topk_limited_devices: int = None
+    """Number of expert parallel ranks to consider for each token during routing. Perform top-k
+    routing on a subset of expert parallel ranks by first selecting N ranks for each token, then
+    conducting top-k selection among experts on these devices. None means no device limitation."""
+
     moe_router_pre_softmax: bool = False
     """Enable pre-softmax routing for MoE, which means softmax is before the top-k selection.
     By default, softmax is done after top-k."""

+    moe_router_topk_scaling_factor: float = None
+    """Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax
+    is enabled. Defaults to None, which means no scaling."""
+
     moe_grouped_gemm: bool = False
     """When there are multiple experts per rank, compress multiple local (potentially small) gemms
     in a single kernel launch to improve the utilization and performance by leveraging the Grouped
     GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
     """

+    moe_use_legacy_grouped_gemm: bool = False
+    """Use legacy GroupedMLP rather than TEGroupedMLP.
+    Note: The legacy one will be deprecated soon."""
+
     moe_aux_loss_coeff: float = 0  # 1e-2 would be a good start value for load balance loss.
     """Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended."""
@@ -354,6 +368,11 @@ class TransformerConfig(ModelParallelConfig):
     enable_cuda_graph: bool = False
     """When set to true, TransformerLayer layers are swapped with a CUDA graphed version."""

+    cuda_graph_retain_backward_graph: bool = False
+    """When set to true, cudagraph backward passes will be graph captured with 'retain_grad=True'.
+    This may enable cudagraphs for certain modules that are not completely cudagraph safe. For
+    more details, see: https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html."""
+
     external_cuda_graph: bool = False
     """When set to true, TransformerLayer layers are swapped with user provided CUDA graphs."""
@@ -428,7 +447,7 @@ class TransformerConfig(ModelParallelConfig):
                 )
             if self.moe_expert_capacity_factor < 0:
                 self.moe_expert_capacity_factor = None
-            if self.moe_router_load_balancing_type not in ["aux_loss", "none"]:
+            if self.moe_router_load_balancing_type not in ["aux_loss", "seq_aux_loss", "none"]:
                 raise ValueError(
                     'moe_expert_capacity_factor only works with aux_loss or none load balancing'
                 )
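The check above now also admits "seq_aux_loss", the DeepSeekV2-style variant that computes the balancing penalty per sequence instead of over the whole batch. A rough standalone sketch of that distinction (illustrative only; names and normalization do not follow the Megatron implementation):

# Rough sketch: global aux loss vs. a per-sequence ("seq_aux_loss") variant.
# Shapes and normalization are illustrative, not the Megatron implementation.
import torch

def aux_loss(router_probs, expert_idx, num_experts):
    # router_probs: [tokens, experts]; expert_idx: [tokens, topk]
    dispatch = torch.zeros(num_experts).index_add_(
        0, expert_idx.flatten(), torch.ones(expert_idx.numel())
    ) / expert_idx.numel()                   # fraction of assignments per expert
    importance = router_probs.mean(dim=0)    # mean router probability per expert
    return num_experts * torch.sum(dispatch * importance)

def seq_aux_loss(router_probs, expert_idx, num_experts, seq_len):
    # Same statistic, but computed independently for each sequence, then averaged.
    probs = router_probs.view(-1, seq_len, num_experts)
    idx = expert_idx.view(probs.shape[0], seq_len, -1)
    per_seq = [aux_loss(probs[b], idx[b], num_experts) for b in range(probs.shape[0])]
    return torch.stack(per_seq).mean()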
@@ -529,9 +548,12 @@ class TransformerConfig(ModelParallelConfig):
             if self.rotary_interleaved:
                 raise ValueError("rotary_interleaved does not work with apply_rope_fusion.")

-            from megatron.core.models.common.embeddings.rope_utils import HAVE_APPLY_ROPE_FUSION
+            from megatron.core.models.common.embeddings.rope_utils import (
+                fused_apply_rotary_pos_emb,
+                fused_apply_rotary_pos_emb_thd,
+            )

-            if not HAVE_APPLY_ROPE_FUSION:
+            if fused_apply_rotary_pos_emb is None and fused_apply_rotary_pos_emb_thd is None:
                 raise ValueError(
                     "apply_rope_fusion is not available. Please install TE >= 1.4 or Apex."
                 )
@@ -569,6 +591,14 @@
                     f"but your version is {get_te_version()}."
                 )

+        if self.moe_router_topk_limited_devices:
+            if self.moe_router_topk_limited_devices > self.expert_model_parallel_size:
+                raise ValueError(
+                    f"moe_router_topk_limited_devices: {self.moe_router_topk_limited_devices} "
+                    f"must be smaller than expert_model_parallel_size "
+                    f"{self.expert_model_parallel_size}"
+                )
+
         if self.flash_decode and self.fp8:
             raise ValueError("FP8 inference is currently not supported with flash decoding.")
@@ -92,7 +92,7 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
         hidden_dropout: float = None,
     ):
         super().__init__(config=config)

         if config.enable_cuda_graph and self.training:
             assert (
                 not config.cpu_offloading and config.recompute_granularity is None
File mode changed from 100755 to 100644
@@ -1413,3 +1413,41 @@ except (ImportError, ModuleNotFoundError):
 def is_float8tensor(tensor: torch.Tensor) -> bool:
     """Check if a tensor is a Transformer Engine Float8Tensor"""
     return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor)
+
+
+########################
+### context parallel ###
+########################
+def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
+    """Slice batch input along sequence dimension into multiple chunks,
+    which are parallelized across GPUs in a context parallel group.
+    """
+    # With causal masking, each token only attends to its prior tokens. Simply splitting the
+    # sequence into CP chunks can result in severe load imbalance, i.e. chunks at the end of
+    # the sequence have a bigger workload than the others. To address this issue, we split
+    # the sequence into 2*CP chunks. Assuming CP=2, we then get 4 chunks; chunk_0 and chunk_3
+    # are assigned to GPU0, and chunk_1 and chunk_2 are assigned to GPU1, so that the workload
+    # is balanced among the GPUs in a context parallel group.
+    cp_size = parallel_state.get_context_parallel_world_size()
+    if cp_size > 1:
+        cp_rank = parallel_state.get_context_parallel_rank()
+        for key, val in batch.items():
+            if val is not None:
+                seq_dim = 1 if key != 'attention_mask' else 2
+                val = val.view(
+                    *val.shape[0:seq_dim],
+                    2 * cp_size,
+                    val.shape[seq_dim] // (2 * cp_size),
+                    *val.shape[(seq_dim + 1) :],
+                )
+                index = torch.tensor(
+                    [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
+                ).cuda(non_blocking=True)
+                val = val.index_select(seq_dim, index)
+                val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
+                batch[key] = val
+
+    return batch
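The comment above describes the 2*CP split used to balance causal-attention work across ranks. A tiny standalone example (plain CPU tensors, no parallel_state machinery) shows the resulting chunk assignment for cp_size=2 and a length-8 sequence:

# Standalone illustration of the 2*CP load-balanced split described above.
import torch

seq = torch.arange(8).unsqueeze(0)   # [batch=1, seq_len=8] -> tokens 0..7
cp_size, seq_dim = 2, 1

for cp_rank in range(cp_size):
    chunks = seq.view(seq.shape[0], 2 * cp_size, seq.shape[seq_dim] // (2 * cp_size))
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    local = chunks.index_select(seq_dim, index).view(seq.shape[0], -1)
    print(cp_rank, local.tolist())
# rank 0 gets chunks 0 and 3 -> [[0, 1, 6, 7]]
# rank 1 gets chunks 1 and 2 -> [[2, 3, 4, 5]]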
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644