Commit 4e2de453 authored by dongcl

megatron patch

parent d77d95c5
# coding=utf-8
import os
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import sys
import types
@@ -38,15 +24,15 @@ class MegatronAdaptation:
# MegatronAdaptation.post_execute()
@classmethod
def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False, apply_wrapper=False):
"""
Register adaptations into collection.
"""
if orig_func_name not in cls._patch_info_collection:
from .patch_utils import Patch
cls._patch_info_collection[orig_func_name] = Patch(orig_func_name, new_func, create_dummy, apply_wrapper=apply_wrapper)
else:
cls._patch_info_collection.get(orig_func_name).set_patch_func(new_func, force_patch, apply_wrapper=apply_wrapper)
@classmethod
def apply(cls):
@@ -138,24 +124,50 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.core.transformer.transformer_config.MLATransformerConfig',
MLATransformerConfig)
# Moe
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.unpermute',
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
def patch_core_extentions(self):
import transformer_engine as te
from ..core.extensions.transformer_engine import te_dot_product_attention_init, TEGroupedLinear
MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
te_dot_product_attention_init)
if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)
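A minimal illustration of the `__bases__` swap above, using toy classes (`BaseA`, `BaseB`, `Child` are made up for this sketch, not transformer_engine types): reassigning a class's bases at runtime re-routes method resolution for the whole class, which is the effect relied on when GROUPED_GEMM_BatchLinear=1 re-parents TEGroupedLinear onto te.pytorch.BatchLinear.

# Toy sketch only; the real patch swaps TEGroupedLinear's base class.
class BaseA:
    def kind(self):
        return "A"

class BaseB:
    def kind(self):
        return "B"

class Child(BaseA):
    pass

obj = Child()
assert obj.kind() == "A"
Child.__bases__ = (BaseB,)   # same runtime re-parenting technique as above
assert obj.kind() == "B"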
def patch_tensor_parallel(self):
from ..core import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
# VocabParallelEmbedding
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
vocab_parallel_embedding_forward)
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
vocab_parallel_embedding_init)
# _VocabParallelCrossEntropy
MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward',
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
def patch_training(self):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import _initialize_distributed
from ..training.initialize import _compile_dependencies
from ..training.training import train
MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
build_tokenizer)
@@ -164,6 +176,10 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
_compile_dependencies)
# training.train
MegatronAdaptation.register('megatron.training.training.train',
train)
def patch_miscellaneous(self):
from ..training.arguments import parse_args
@@ -176,7 +192,22 @@ class LegacyAdaptation(MegatronAdaptationABC):
""" """
def execute(self): def execute(self):
pass self.patch_legacy_models()
def patch_legacy_models(self):
from ..legacy.model.transformer import ParallelMLP, ParallelAttention
# ParallelMLP
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
ParallelMLP.__init__)
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.forward',
ParallelAttention.forward)
# rms_norm.RMSNorm
MegatronAdaptation.register('megatron.legacy.model.rms_norm.RMSNorm.forward',
torch.compile(mode="max-autotune-no-cudagraphs"),
apply_wrapper=True)
MegatronAdaptation.execute()
@@ -17,7 +17,7 @@ def dummy_function_wrapper(func_name):
class Patch:
def __init__(self, orig_func_name, new_func, create_dummy, apply_wrapper=False):
split_name = orig_func_name.rsplit('.', 1)
if len(split_name) == 1:
self.orig_module_name, self.orig_func_name = orig_func_name, None
@@ -30,7 +30,7 @@ class Patch:
self.wrappers = []
if new_func is None:
new_func = dummy_function_wrapper(orig_func_name)
self.set_patch_func(new_func, apply_wrapper=apply_wrapper)
self.is_applied = False
self.create_dummy = create_dummy
@@ -42,8 +42,11 @@ class Patch:
def patch_func_id(self):
return id(self.patch_func)
def set_patch_func(self, new_func, force_patch=False, apply_wrapper=False):
if (
apply_wrapper
or (hasattr(new_func, '__name__') and new_func.__name__.endswith(('wrapper', 'decorator')))
):
self.wrappers.append(new_func)
else:
if self.patch_func and not force_patch:
...
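A rough standalone sketch (the orig_permute function and the manual wrapping below are illustrative, not part of the patch) of why the torch.compile(...) entries above are registered with apply_wrapper=True: torch.compile called without a function returns a decorator, so Patch has to apply it on top of the original Megatron function via its wrappers list instead of substituting it.

# Standalone illustration of decorator-style torch.compile usage.
import torch

def orig_permute(tokens, order):
    return tokens[order]

compile_decorator = torch.compile(mode='max-autotune-no-cudagraphs')
patched_permute = compile_decorator(orig_permute)  # roughly what the wrappers list ends up doing

tokens = torch.randn(4, 8)
order = torch.tensor([3, 1, 0, 2])
assert torch.equal(patched_permute(tokens, order), orig_permute(tokens, order))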
import torch
import torch.nn.functional as F
from einops import rearrange
from megatron.training import get_args
from megatron.core import tensor_parallel
from megatron.legacy.model.enums import AttnType
from megatron.legacy.model.module import MegatronModule
from megatron.legacy.model.utils import (
erf_gelu,
openai_gelu,
)
# Note: ParallelAttention.forward below also relies on apply_rotary_pos_emb;
# its import path depends on the Megatron-LM version in use.
class ParallelMLP(MegatronModule):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self, config, is_expert=False):
super(ParallelMLP, self).__init__()
args = get_args()
self.add_bias = config.add_bias_linear
ffn_hidden_size = config.ffn_hidden_size
if config.gated_linear_unit:
ffn_hidden_size *= 2
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
ffn_hidden_size,
config=config,
init_method=config.init_method,
bias=self.add_bias,
gather_output=False,
skip_bias_add=True,
is_expert=is_expert,
)
self.bias_gelu_fusion = False
self.activation_func = None
self.swiglu = args.swiglu
if args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu
elif args.swiglu:
@torch.compile(mode="max-autotune-no-cudagraphs")
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.activation_func = swiglu
elif args.squared_relu:
def squared_relu(x):
return torch.pow(F.relu(x), 2)
self.activation_func = squared_relu
else:
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
# Project back to h.
self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
config.ffn_hidden_size,
config.hidden_size,
config=config,
init_method=config.output_layer_init_method,
bias=self.add_bias,
skip_bias_add=True,
input_is_parallel=True,
is_expert=is_expert,
)
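Beyond the constructor above, the compiled swiglu activation is just a chunk-and-gate. A small eager-mode reference follows (the name swiglu_reference and the toy shapes are made up here) to show the expected shapes.

import torch
import torch.nn.functional as F

def swiglu_reference(x):
    # Split the doubled ffn projection in half and gate with SiLU,
    # mirroring the torch.compile'd closure defined in __init__ above.
    gate, value = torch.chunk(x, 2, dim=-1)
    return F.silu(gate) * value

x = torch.randn(2, 3, 8)        # last dim = 2 * ffn shard
assert swiglu_reference(x).shape == (2, 3, 4)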
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None,
rotary_pos_emb=None):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
is_first_step = False
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size,
self.num_query_groups_per_partition)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size,
self.num_query_groups_per_partition)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory, inference_value_memory)
is_first_step = True
else:
inference_key_memory, inference_value_memory = \
inference_params.key_value_memory_dict[self.layer_number]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_query_groups_per_partition,
(
(self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
* self.hidden_size_per_attention_head
),
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query_layer,
key_layer,
value_layer) = torch.split(
mixed_x_layer,
[
(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head
],
dim=3)
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] -
query_layer = query_layer.contiguous().view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
# duplicate the pos_emb for self attention
if rotary_pos_emb is not None:
if isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = rotary_pos_emb
else:
rotary_pos_emb = ((rotary_pos_emb,) * 2)
if inference_params:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = key_layer
inference_value_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[
:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...]
# adjust the key rotary positional embedding
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
# need to cross check this condition during inference
# if not set_inference_key_value_memory:
if not is_first_step:
# In inference, we compute one token at a time.
# Select the correct positional embedding
# (only the last token in the sequence)
q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
else:
# In the first forward pass of inference,
# we use the entire provided prefix.
# q_pos_emb here has the rope embeddings of the entire
# prefix + to-be-generated output so
# we slice to just the prefix.
q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
rotary_pos_emb = (q_pos_emb, k_pos_emb)
# ==================================
# core attention computation
# ==================================
# expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
key_layer = key_layer.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
dim = 2
)
value_layer = value_layer.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
dim = 2
)
# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.config)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.config)
# TODO, can apply positional embedding to value_layer so it has
# absolute positional embedding.
# otherwise, only relative positional embedding takes effect
# value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
if not self.use_flash_attn:
if self.checkpoint_core_attention:
context_layer = self._checkpointed_attention_forward(
query_layer, key_layer, value_layer, attention_mask)
else:
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask)
else:
q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
for x in (query_layer, key_layer, value_layer)]
if not self.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
context_layer = self.core_attention_flash(q, k, v)
else:
context_layer = self.core_attention_flash(q, k, v)
context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
return output, bias
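A toy shape check (dimensions chosen arbitrarily, not taken from any config) of the grouped-query expansion performed in the forward pass above: key/value tensors carrying ng KV heads are repeat-interleaved up to the full number of attention heads before core attention.

import torch

sk, b, ng, hn, num_heads = 5, 2, 2, 4, 8   # hypothetical sizes
key_layer = torch.randn(sk, b, ng, hn)
key_expanded = key_layer.repeat_interleave(num_heads // ng, dim=2)
assert key_expanded.shape == (sk, b, num_heads, hn)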
@@ -4,7 +4,6 @@ import argparse
from megatron.training.arguments import (
_add_network_size_args,
_add_regularization_args,
_add_initialization_args,
_add_learning_rate_args,
_add_checkpointing_args,
@@ -249,6 +248,8 @@ def _add_tokenizer_args(parser):
'GPTSentencePieceTokenizer',
'HuggingFaceTokenizer',
'Llama2Tokenizer',
'Llama3Tokenizer',
'QwenTokenizer',
'TikTokenizer',
'MultimodalTokenizer',
'NullTokenizer',
@@ -265,6 +266,255 @@ def _add_tokenizer_args(parser):
return parser
def _add_training_args(parser):
group = parser.add_argument_group(title='training')
group.add_argument('--micro-batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size times number of micro batches.')
group.add_argument('--batch-size', type=int, default=None,
help='Old batch size parameter, do not use. '
'Use --micro-batch-size instead')
group.add_argument('--global-batch-size', type=int, default=None,
help='Training batch size. If set, it should be a '
'multiple of micro-batch-size times data-parallel-size. '
'If this value is None, then '
'use micro-batch-size * data-parallel-size as the '
'global batch size. This choice will result in 1 for '
'number of micro-batches.')
group.add_argument('--rampup-batch-size', nargs='*', default=None,
help='Batch size ramp up with the following values:'
' --rampup-batch-size <start batch size> '
' <batch size increment> '
' <ramp-up samples> '
'For example:'
' --rampup-batch-size 16 8 300000 \\ '
' --global-batch-size 1024 '
'will start with global batch size 16 and over '
'(1024 - 16) / 8 = 126 intervals will increase '
'the batch size linearly to 1024. In each interval '
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--decrease-batch-size-if-needed', action='store_true', default=False,
help='If set, decrease batch size if microbatch_size * dp_size '
'does not divide batch_size. Useful for KSO (Keep Soldiering On) '
'to continue making progress if number of healthy GPUs (and '
'corresponding dp_size) does not support current batch_size. '
'Old batch_size will be restored if training is re-started with '
'dp_size that divides batch_size // microbatch_size.')
group.add_argument('--recompute-activations', action='store_true',
help='recompute activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--recompute-granularity', type=str, default=None,
choices=['full', 'selective'],
help='Checkpoint activations to allow for training '
'with larger models, sequences, and batch sizes. '
'It is supported at two granularities 1) full: '
'whole transformer layer is recomputed, '
'2) selective: core attention part of the transformer '
'layer is recomputed.')
group.add_argument('--no-check-for-nan-in-loss-and-grad', action='store_false',
help='Check for NaNs in loss and grad',
dest='check_for_nan_in_loss_and_grad')
group.add_argument('--check-for-spiky-loss', action='store_true',
help='Check for spiky loss',
dest='check_for_spiky_loss')
group.add_argument('--distribute-saved-activations',
action='store_true',
help='If set, distribute recomputed activations '
'across model parallel group.')
group.add_argument('--recompute-method', type=str, default=None,
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers and recompute the input activation of '
'each divided chunk at specified granularity, '
'2) recompute the input activations of only a set number of '
'individual Transformer layers per pipeline stage and do the '
'rest without any recomputing at specified granularity, '
'default) do not apply activations recompute to any layers')
group.add_argument('--recompute-num-layers', type=int, default=None,
help='1) uniform: the number of Transformer layers in each '
'uniformly divided recompute unit, '
'2) block: the number of individual Transformer layers '
'to recompute within each pipeline stage.')
group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false',
help='If not set, clone the output of the scatter in embedding layer to GC original tensor.',
dest='clone_scatter_output_in_embedding')
group.add_argument('--profile', action='store_true',
help='Enable nsys profiling. When using this option, nsys '
'options should be specified in commandline. An example '
'nsys commandline is `nsys profile -s none -t nvtx,cuda '
'-o <path/to/output_file> --force-overwrite true '
'--capture-range=cudaProfilerApi '
'--capture-range-end=stop`.')
group.add_argument('--profile-step-start', type=int, default=10,
help='Global step to start profiling.')
group.add_argument('--profile-step-end', type=int, default=12,
help='Global step to stop profiling.')
group.add_argument('--use-pytorch-profiler', action='store_true',
help='Use the built-in pytorch profiler. '
'Useful if you wish to view profiles in tensorboard.',
dest='use_pytorch_profiler')
group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
help='Global ranks to profile.')
group.add_argument('--record-memory-history', action="store_true", default=False,
help='Record memory history in last rank.')
group.add_argument('--memory-snapshot-path', type=str, default="snapshot.pickle",
help='Specifies where to dump the memory history pickle.')
group.add_argument('--tp-comm-overlap', action='store_true', help='Enables the '
' overlap of Tensor parallel communication and GEMM kernels.')
group.add_argument('--tp-comm-overlap-cfg', type=str, default=None,
help='Config file when tp_comm_overlap is enabled.')
group.add_argument('--disable-tp-comm-overlap-ag', action='store_false',
help=('Disables the All-Gather overlap with GEMM by '
'pipelining the GEMM and All-Gather.'),
dest='tp_comm_overlap_ag')
group.add_argument('--disable-tp-comm-overlap-rs', action='store_false',
help=('Disables the Reduce-Scatter overlap with GEMM by '
'pipelining the GEMM and Reduce-Scatter.'),
dest='tp_comm_overlap_rs')
group.add_argument('--tp-comm-overlap-rs-dgrad', action='store_true',
help = 'Enables the Reduce-Scatter overlap with dgrad GEMM.',
dest='tp_comm_overlap_rs_dgrad')
group.add_argument('--disable-tp-comm-bulk-dgrad', action='store_false',
help='Disables the All-Gather overlap with bprop activation gradient GEMM.',
dest='tp_comm_bulk_dgrad')
group.add_argument('--disable-tp-comm-bulk-wgrad', action='store_false',
help='Disables the Reduce-Scatter overlap with bprop weight gradient GEMM.',
dest='tp_comm_bulk_wgrad')
group.add_argument('--tp-comm-bootstrap-backend', default='nccl', type=str,
choices=['nccl', 'mpi', 'gloo'],
help='Set the bootstrapping backend of Tensor parallel communications.')
group.add_argument('--use-cpu-initialization', action='store_true',
default=None,
help='If set, initialize weights on the CPU. This eliminates init differences based on tensor parallelism.')
group.add_argument('--empty-unused-memory-level', default=0, type=int,
choices=[0, 1, 2],
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation. '
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--deterministic-mode', action='store_true',
help='Choose code that has deterministic execution. This usually '
'means slower execution, but is good for debugging and testing.')
group.add_argument('--check-weight-hash-across-dp-replicas-interval', type=int, default=None,
help='Interval to check weight hashes are same across DP replicas. If not specified, weight hashes not checked.')
group.add_argument('--calculate-per-token-loss', action='store_true',
help=('Scale cross entropy loss by the number of non-padded tokens in the '
'global batch, versus the default behavior of assuming all tokens are non-padded.'))
group.add_argument('--train-sync-interval', type=int, default=None,
help='Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.')
# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--train-samples', type=int, default=None,
help='Total number of samples to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--log-interval', type=int, default=100,
help='Report loss and timing interval.')
group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after the iteration is divisible '
'by this value.')
group.add_argument('--exit-duration-in-mins', type=int, default=None,
help='Exit the program after this many minutes.')
group.add_argument('--exit-signal-handler', action='store_true',
help='Dynamically save the checkpoint and shutdown the '
'training if SIGTERM is received')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--no-masked-softmax-fusion',
action='store_false',
help='Disable fusion of query_key_value scaling, '
'masking, and softmax.',
dest='masked_softmax_fusion')
group.add_argument('--no-bias-gelu-fusion', action='store_false',
help='Disable bias and gelu fusion.',
dest='bias_gelu_fusion')
group.add_argument('--no-bias-swiglu-fusion', action='store_false',
help='Disable bias and swiglu fusion, the fusion is '
'available only when using megatron-core.',
dest='bias_swiglu_fusion')
group.add_argument('--no-bias-dropout-fusion', action='store_false',
help='Disable bias and dropout fusion.',
dest='bias_dropout_fusion')
group.add_argument('--no-rope-fusion', action='store_false',
help='Disable rope fusion, the fusion is available '
'only when using megatron-core.',
dest='apply_rope_fusion')
group.add_argument('--cross-entropy-loss-fusion', action='store_true',
help='Enabled fusion of cross entropy loss calculation.',
dest='cross_entropy_loss_fusion')
group.add_argument('--use-flash-attn', action='store_true',
help='use FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
group.add_argument('--disable-bias-linear', action='store_false',
help='Disable bias in the linear layers',
dest='add_bias_linear')
group.add_argument('--add-qkv-bias', action='store_true',
help='Enable bias only in the QKV linear layers',
dest='add_qkv_bias')
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')
group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic', 'external'],
help='Single pass vs multiple pass data loader')
group.add_argument('--no-async-tensor-model-parallel-allreduce',
action='store_false',
help='DEPRECATED. This flag is ignored.',
dest='async_tensor_model_parallel_allreduce')
group.add_argument('--no-persist-layer-norm', action='store_true',
help='Disable using persistent fused layer norm kernel. '
'This kernel supports only a set of hidden sizes. Please '
'check persist_ln_hidden_sizes if your hidden '
'size is supported.')
group.add_argument('--sequence-parallel', action='store_true',
help='Enable sequence parallel optimization.')
group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fusing gradient accumulation to weight '
'gradient computation of linear layers',
dest='gradient_accumulation_fusion')
group.add_argument('--use-mcore-models', action='store_true',
dest='deprecated_use_mcore_models',
help='DEPRECATED. Use the implementation from megatron core. '
'Now ignored and mcore models are the default, use '
'--use-legacy-models to not use core models.')
group.add_argument('--use-legacy-models', action='store_true',
help='Use the legacy Megatron models, not Megatron-Core models.')
group.add_argument('--manual-gc', action='store_true',
help='Disable the threshold-based default garbage '
'collector and trigger the garbage collection manually. '
'Manual garbage collection helps to align the timing of '
'the collection across ranks which mitigates the impact '
'of CPU-associated jitters. When the manual gc is enabled, '
'garbage collection is performed only at the start and the '
'end of the validation routine by default.')
group.add_argument('--manual-gc-interval', type=int, default=0,
help='Training step interval to trigger manual garbage '
'collection. When the value is set to 0, garbage '
'collection is not triggered between training steps.')
group.add_argument('--no-manual-gc-eval', action='store_false',
help='When using manual garbage collection, disable '
'garbage collection at the start and the end of each '
'evaluation run.', dest='manual_gc_eval')
group.add_argument('--disable-tp-comm-split-ag', action='store_false',
help='Disables the All-Gather overlap with fprop GEMM.',
dest='tp_comm_split_ag')
group.add_argument('--disable-tp-comm-split-rs', action='store_false',
help='Disables the Reduce-Scatter overlap with fprop GEMM.',
dest='tp_comm_split_rs')
group.add_argument('--profile-dir', type=str, default="./",
help='profile dir to save.')
return parser
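A hedged usage sketch of the redefined training group: a bare argparse parser with a small subset of the options above (including the new --profile-dir flag); the option subset and sample values are illustrative only, not the full Megatron argument surface.

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--micro-batch-size', type=int, default=None)
group.add_argument('--global-batch-size', type=int, default=None)
group.add_argument('--profile-dir', type=str, default='./')

args = parser.parse_args(['--micro-batch-size', '2',
                          '--global-batch-size', '128',
                          '--profile-dir', './prof'])
assert args.micro_batch_size == 2 and args.profile_dir == './prof'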
def _add_mtp_args(parser):
group = parser.add_argument_group(title='multi token prediction')
group.add_argument('--num-nextn-predict-layers', type=int, default=0, help='Multi-Token prediction layer num')
...
from transformers import AutoTokenizer, Qwen2Tokenizer
from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
from megatron.training.tokenizer.tokenizer import (
_BertWordPieceTokenizer,
@@ -46,6 +46,11 @@ def build_tokenizer(args, **kwargs):
elif args.tokenizer_type == 'Llama2Tokenizer':
assert args.tokenizer_model is not None
tokenizer = _Llama2Tokenizer(args.tokenizer_model)
elif args.tokenizer_type == 'Llama3Tokenizer':
assert args.tokenizer_model is not None
tokenizer = _Llama3Tokenizer(args.tokenizer_model)
elif args.tokenizer_type == 'QwenTokenizer':
tokenizer = _Qwen2Tokenizer(args.vocab_file, args.merge_file)
elif args.tokenizer_type == 'TikTokenizer':
assert args.tokenizer_model is not None
assert args.tiktoken_pattern is not None
@@ -101,6 +106,96 @@ def build_tokenizer(args, **kwargs):
return tokenizer
class _Llama3Tokenizer(MegatronTokenizer):
"""tiktokenTokenizer-Megatron llama3 改写"""
# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py
def __init__(self, model_file):
super().__init__(model_file)
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
tokenizer_path = model_file
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
self.tokenizer = tiktoken.Encoding(tokenizer_path,
pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
mergeable_ranks=mergeable_ranks,
special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
)
self.eod_id = self.tokenizer.encode("<|end_of_text|>", allowed_special="all")[0]
@property
def vocab_size(self):
return self.tokenizer.n_vocab
@property
def vocab(self):
return self.tokenizer.encode
@property
def inv_vocab(self):
return self.tokenizer.decode
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.eod_id
class _Qwen2Tokenizer(MegatronTokenizer):
def __init__(self, vocab_file, merge_file, extra_vocab_size=0):
super().__init__(vocab_file, merge_file)
self.tokenizer = Qwen2Tokenizer(vocab_file, merge_file)
self.extra_vocab_size = extra_vocab_size
self.tokenizer.add_special_tokens(special_tokens_dict=dict(pad_token="<|extra_0|>"))
@property
def vocab_size(self):
return len(self.tokenizer.encoder) + self.extra_vocab_size
@property
def vocab(self):
return self.tokenizer.encoder
@property
def inv_vocab(self):
return self.tokenizer.decoder
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.tokenizer.eos_token_id
@property
def eos_token(self):
return self.tokenizer.eos_token
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
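A sketch of how the new tokenizer branch in build_tokenizer is reached (the args namespace and file paths below are placeholders, so this is not runnable as-is without real vocab/merge files):

import types

args = types.SimpleNamespace(
    tokenizer_type='QwenTokenizer',
    vocab_file='/path/to/vocab.json',   # placeholder path
    merge_file='/path/to/merges.txt',   # placeholder path
)
if args.tokenizer_type == 'QwenTokenizer':
    tokenizer = _Qwen2Tokenizer(args.vocab_file, args.merge_file)
    ids = tokenizer.tokenize("hello world")
    text = tokenizer.detokenize(ids)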
class _DeepSeekV2Tokenizer(MegatronTokenizer):
def __init__(self, tokenizer_path, extra_vocab_size):
super().__init__(tokenizer_path)
...
import gc
import sys
import torch.distributed
import torch
from megatron.core import mpu
from megatron.core.utils import (
check_param_hashes_across_dp_replicas,
StragglerDetector,
)
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed import finalize_model_grads
from megatron.training.initialize import write_args_to_tensorboard
from megatron.core.num_microbatches_calculator import (
get_current_global_batch_size,
get_current_running_global_batch_size,
get_num_microbatches,
update_num_microbatches)
from megatron.training.async_utils import maybe_finalize_async_save
from megatron.training.utils import (
calc_params_l2_norm,
print_rank_0,
)
from megatron.training.global_vars import (
get_args,
get_timers,
get_tensorboard_writer,
get_wandb_writer,
get_one_logger,
)
from megatron.training import one_logger_utils
from megatron.training import ft_integration
from megatron.training.training import (
print_datetime,
disable_forward_pre_hook,
train_step,
save_checkpoint_and_time,
enable_forward_pre_hook,
num_floating_point_operations,
training_log,
evaluate_and_print_results,
post_training_step_callbacks,
checkpoint_and_decide_exit,
)
stimer = StragglerDetector()
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
"""Training function: run train_step desired number of times, run validation, checkpoint."""
args = get_args()
timers = get_timers()
one_logger = get_one_logger()
# Write args to tensorboard
write_args_to_tensorboard()
# Turn on training mode which enables dropout.
for model_module in model:
model_module.train()
# Tracking loss.
total_loss_dict = {}
# Iterations.
iteration = args.iteration
# Track E2E metrics at the start of training.
one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples, seq_length=args.seq_length,
train_iters=args.train_iters, save=args.save, async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)
num_floating_point_operations_so_far = args.num_floating_point_operations_so_far
# Setup some training config params.
config.grad_scale_func = optimizer.scale_loss
config.timers = timers
if isinstance(model[0], DDP) and args.overlap_grad_reduce:
assert config.no_sync_func is None, \
('When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce')
config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
if len(model) == 1:
config.no_sync_func = config.no_sync_func[0]
if args.align_grad_reduce:
config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model]
if len(model) == 1:
config.grad_sync_func = config.grad_sync_func[0]
if args.overlap_param_gather and args.align_param_gather:
config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
if len(model) == 1:
config.param_sync_func = config.param_sync_func[0]
config.finalize_model_grads_func = finalize_model_grads
timers('interval-time', log_level=0).start(barrier=True)
print_datetime('before the start of training step')
report_memory_flag = True
pre_hook_enabled = False
should_exit = False
exit_code = 0
if args.manual_gc:
# Disable the default garbage collector and perform the collection manually.
# This is to align the timing of garbage collection across ranks.
assert args.manual_gc_interval >= 0, \
'Manual garbage collection interval should be larger than or equal to 0'
gc.disable()
gc.collect()
# Singleton initialization of straggler detector.
if args.log_straggler:
global stimer
world = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
mmcnt = args.straggler_minmax_count
stimer.configure(world, rank,
mmcnt = mmcnt,
enabled = not args.disable_straggler_on_startup,
port = args.straggler_ctrlr_port)
num_floating_point_operations_since_last_log_event = 0.0
num_microbatches = get_num_microbatches()
eval_duration = 0.0
eval_iterations = 0
def get_e2e_base_metrics():
"""Get base metrics values for one-logger to calculate E2E tracking metrics.
"""
num_floating_point_operations_since_current_train_start = \
num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
return {
'iteration': iteration,
'train_duration': timers('interval-time').active_time(),
'eval_duration': eval_duration,
'eval_iterations': eval_iterations,
'total_flops_since_current_train_start': num_floating_point_operations_since_current_train_start,
'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
'consumed_train_samples': args.consumed_train_samples,
'world_size': args.world_size,
'seq_length': args.seq_length
}
# Cache into one-logger for callback.
if one_logger:
with one_logger.get_context_manager():
one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)
prof = None
if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
def trace_handler(p):
from pathlib import Path
Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
if args.rank in [0]:
print(p.key_averages(group_by_input_shape=True,
group_by_stack_n=5).table(sort_by="self_cuda_time_total",
row_limit=-1,
max_src_column_width=100,
max_name_column_width=280,
max_shapes_column_width=200))
p.export_chrome_trace("{path}/trace_rank{rank}_step{step}.json".format(
path=args.profile_dir, rank=torch.distributed.get_rank(), step=p.step_num))
prof = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start-1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end-args.profile_step_start,
repeat=1),
record_shapes=True,
#on_trace_ready=torch.profiler.tensorboard_trace_handler('./torch_prof_data'))
on_trace_ready=trace_handler)
prof.start()
start_iteration = iteration
# Disable forward pre-hook to start training to ensure that errors in checkpoint loading
# or random initialization don't propagate to all ranks in first all-gather (which is a
# no-op if things work correctly).
if args.use_distributed_optimizer and args.overlap_param_gather:
disable_forward_pre_hook(model, param_sync=False)
# Also remove param_sync_func temporarily so that sync calls made in
# `forward_backward_func` are no-ops.
param_sync_func = config.param_sync_func
config.param_sync_func = None
pre_hook_enabled = False
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
# Run training iterations till done.
while iteration < args.train_iters:
if args.profile and torch.distributed.get_rank() in args.profile_ranks:
if args.use_pytorch_profiler:
prof.step()
elif iteration == args.profile_step_start:
torch.cuda.cudart().cudaProfilerStart()
torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=False)
ft_integration.on_checkpointing_end(is_async_finalization=True)
# Update number of microbatches first without consistency check to decide if a
# checkpoint should be saved. If the number of microbatches is different
# from the previous iteration, save a checkpoint. Then run consistency check
# to make sure training configuration is still valid.
update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, \
(f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}")
if args.save is not None:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
# Run training step.
args.curr_iteration = iteration
ft_integration.on_training_step_start()
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
ft_integration.on_training_step_end()
if should_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
if should_exit:
break
# Enable forward pre-hooks after first set of forward and backward passes.
# When running in fp16, skip all NaN iterations until steady-state loss scaling value
# is reached.
if iteration == start_iteration:
if skipped_iter:
# Only enable forward pre-hook after a training step has successfully run. Relevant
# for fp16 codepath where first XX iterations are skipped until steady-state loss
# scale value is reached.
start_iteration = iteration + 1
else:
# Enable forward pre-hook after training step has successfully run. All subsequent
# forward passes will use the forward pre-hook / `param_sync_func` in
# `forward_backward_func`.
if args.use_distributed_optimizer and args.overlap_param_gather:
enable_forward_pre_hook(model)
config.param_sync_func = param_sync_func
pre_hook_enabled = True
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
num_skipped_samples_in_batch = (get_current_global_batch_size() -
get_current_running_global_batch_size())
if args.decrease_batch_size_if_needed:
assert num_skipped_samples_in_batch >= 0
else:
assert num_skipped_samples_in_batch == 0
args.skipped_train_samples += num_skipped_samples_in_batch
num_floating_point_operations_in_batch = num_floating_point_operations(args, batch_size)
num_floating_point_operations_so_far += num_floating_point_operations_in_batch
num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch
# Logging.
if not optimizer.is_stub_optimizer:
loss_scale = optimizer.get_loss_scale().item()
else:
loss_scale = 1.0
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
learning_rate = None
decoupled_learning_rate = None
for param_group in optimizer.param_groups:
if param_group['is_decoupled_lr']:
decoupled_learning_rate = param_group['lr']
else:
learning_rate = param_group['lr']
report_memory_flag = training_log(loss_dict, total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
timers('interval-time').stop()
if args.use_distributed_optimizer and args.overlap_param_gather:
disable_forward_pre_hook(model)
pre_hook_enabled = False
if args.manual_gc and args.manual_gc_eval:
# Collect all objects.
gc.collect()
prefix = f'iteration {iteration}'
timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
eval_duration += timers('eval-time').elapsed()
eval_iterations += args.eval_iters
timers('eval-time').stop()
one_logger_utils.track_e2e_metrics()
if args.manual_gc and args.manual_gc_eval:
# Collect only the objects created and used in evaluation.
gc.collect(generation=0)
if args.use_distributed_optimizer and args.overlap_param_gather:
enable_forward_pre_hook(model)
pre_hook_enabled = True
timers('interval-time', log_level=0).start(barrier=True)
# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
num_floating_point_operations_since_last_log_event)
# Checkpoint and decide whether to exit.
should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
if should_exit:
break
one_logger_utils.track_e2e_metrics()
# Flush TensorBoard, WandB writers and one-logger.
writer = get_tensorboard_writer()
if writer:
writer.flush()
# Close out pre-hooks if using distributed optimizer and overlapped param gather.
if pre_hook_enabled:
disable_forward_pre_hook(model)
ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)
# If any exit conditions (signal handler, duration, iterations) have been reached, exit.
if should_exit:
wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()
ft_integration.shutdown()
sys.exit(exit_code)
return iteration, num_floating_point_operations_so_far
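A minimal sketch of the profiler-window arithmetic used by the torch.profiler setup in train() above (the values for profile_step_start/profile_step_end are the argument defaults, 10 and 12):

import torch

profile_step_start, profile_step_end = 10, 12
schedule = torch.profiler.schedule(
    wait=max(profile_step_start - 1, 0),           # idle steps before the window
    warmup=1 if profile_step_start > 0 else 0,     # one warmup step unless starting at 0
    active=profile_step_end - profile_step_start,  # steps actually recorded
    repeat=1,
)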