Commit c63fceee authored by dongcl

Merge branch 'a2a_overlap' into 'core_v0.12.0'

A2a overlap

See merge request OpenDAS/dcu_megatron!4
parents 6c3cfb1d bfe0b4a9
...@@ -169,6 +169,15 @@ class CoreAdaptation(MegatronAdaptationABC):
staticmethod,
apply_wrapper=True)
# reduce_scatter_to_sequence_parallel_region
MegatronAdaptation.register('megatron.core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region',
torch._dynamo.disable,
apply_wrapper=True)
# reduce_from_tensor_model_parallel_region
MegatronAdaptation.register('megatron.core.tensor_parallel.mappings.reduce_from_tensor_model_parallel_region',
torch._dynamo.disable,
apply_wrapper=True)
# flux
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
from ..core.tensor_parallel.layers import (
...@@ -189,6 +198,7 @@ class CoreAdaptation(MegatronAdaptationABC):
from ..training.initialize import _initialize_distributed
from ..training.initialize import _compile_dependencies
from ..training.training import train
from ..training.initialize import _set_random_seed
MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
build_tokenizer)
...@@ -199,6 +209,10 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
_compile_dependencies)
# add fixed random seed
MegatronAdaptation.register('megatron.training.initialize._set_random_seed',
_set_random_seed)
# add trace_handler
MegatronAdaptation.register('megatron.training.training.train',
train)
...
...@@ -397,6 +397,10 @@ class DenseAttnNode(TransformerLayerNode):
)
return hidden_states
def dw(self):
with torch.cuda.nvtx.range(f"{self.name} wgrad"):
self.layer._submodule_attention_dw()
class FakeScheduleNode:
...@@ -411,6 +415,10 @@ class DenseMlpNode(TransformerLayerNode):
def forward_impl(self, hidden_states):
return self.layer._submodule_dense_forward(hidden_states)
def dw(self):
with torch.cuda.nvtx.range(f"{self.name} wgrad"):
self.layer._submodule_mlp_dw()
def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream):
common_state = TransformerLayerState()
...@@ -418,6 +426,7 @@ def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream)
attn.name = "attn"
dispatch = FakeScheduleNode()
mlp = DenseMlpNode(chunk_state, common_state, layer, comp_stream, event)
mlp.name = "mlp"
combine = FakeScheduleNode()
return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
...
...@@ -7,6 +7,7 @@ from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.pipeline_parallel.schedules import set_current_microbatch
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.utils import (
get_attr_wrapped_model,
...@@ -28,19 +29,6 @@ from megatron.core.pipeline_parallel.schedules import (
from .combined_1f1b import VppContextManager, forward_backward_step, set_streams, wrap_forward_func
def set_current_microbatch(model, microbatch_id):
"""Set the current microbatch."""
decoder_exists = True
decoder = None
try:
decoder = get_attr_wrapped_model(model, "decoder")
except RuntimeError:
decoder_exists = False
if decoder_exists and decoder is not None:
for layer in decoder.layers:
layer.current_microbatch = microbatch_id
def get_pp_rank_microbatches(
num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage, forward_only=False
):
...
from .mappings import all_to_all
\ No newline at end of file
import torch
from .qcomm import q_alltoall
class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_qcomm=False):
"""Forward function."""
ctx.group = group
ctx.output_split_sizes = output_split_sizes
ctx.input_split_sizes = input_split_sizes
ctx.use_qcomm = use_qcomm
world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input
input = input.contiguous()
if output_split_sizes is None:
# Equal split (all2all)
if use_qcomm:
output = input.new_empty(
size=[input.shape[0], input.shape[1]+4],
dtype=torch.int8,
device=torch.cuda.current_device(),
)
else:
output = torch.empty_like(input)
else:
# Unequal split (all2all-v)
if use_qcomm:
# int8 payload plus 4 extra bytes per row for the packed bf16 scale / zero point
output = input.new_empty(
size=[sum(output_split_sizes), input.size(1) + 4],
dtype=torch.int8,
device=torch.cuda.current_device(),
)
else:
output = input.new_empty(
size=[sum(output_split_sizes)] + list(input.size()[1:]),
dtype=input.dtype,
device=torch.cuda.current_device(),
)
if use_qcomm:
output = q_alltoall(output, input, output_split_sizes, input_split_sizes, group)
else:
torch.distributed.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
return output
@staticmethod
def backward(ctx, *grad_output):
"""Backward function."""
return (
None,
_AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.use_qcomm),
None,
None,
None,
)
def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None, use_qcomm=False):
"""Wrapper for autograd function"""
return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_qcomm)
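A minimal usage sketch (editorial, not part of the patch): ep_group is a placeholder for an already-created expert-parallel process group, and torch.distributed must be initialized. With use_qcomm=True the payload travels as int8 and is returned dequantized to bfloat16; the MoE dispatcher passes explicit output/input split lists for the all2all-v case.
# tokens: [num_local_tokens, hidden_size] bfloat16 on this rank
tokens = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")
# Equal split across the group (all2all); the dispatcher instead passes split lists.
global_tokens = all_to_all(ep_group, tokens, None, None, use_qcomm=True)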
import torch
import triton
import triton.language as tl
import random
import unittest
import json
import os
import time
@triton.jit
def _fwd_kernel_destindex_copy_quantize_kv_init_asym(
K, Out, Out_scale_zero,
stride_k_bs, stride_k_h, stride_k_d,
stride_o_bs, stride_o_h, stride_o_d,
stride_os_bs, stride_os_h, stride_os_d,
head_num,head_dim,
BLOCK_DMODEL: tl.constexpr,
BLOCK_HEAD: tl.constexpr
):
cur_index = tl.program_id(0)
offs_h = tl.arange(0, BLOCK_HEAD)
offs_d = tl.arange(0, BLOCK_DMODEL)
dest_index = cur_index
m1 = offs_h[:, None] < head_num
m2 = offs_d[None,:] < head_dim
mask = m1&m2
src_data = tl.load(K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],
mask=mask, other=0.0).to(tl.float32)
src_data_max = tl.max(src_data, axis=1, keep_dims=True)
src_data_min = tl.min(src_data, axis=1, keep_dims=True)
data_scale = (src_data_max - src_data_min) / 255.0
data_zero = (-1 * src_data_min / data_scale).to(tl.int32)
q_src_data = (tl.clamp((src_data / data_scale).to(tl.int32).to(tl.float32) + data_zero.to(tl.float32), 0.0, 255.0).to(tl.int32) - 128).to(tl.int8)
data_scale = data_scale.to(Out_scale_zero.dtype.element_ty)
data_zero = data_zero.to(Out_scale_zero.dtype.element_ty)
o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
os_ptrs = Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]
oz_ptrs = Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] + 1
tl.store(o_ptrs, q_src_data, mask=mask)
tl.store(os_ptrs, data_scale, mask=m1)
tl.store(oz_ptrs, data_zero, mask=m1)
@torch.no_grad()
def destindex_copy_quantize_kv_init_asym(K, Out, Out_scale_zero):
bs_seq = K.shape[0]
head_num = K.shape[1]
head_dim = K.shape[2]
assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]
BLOCK_HEAD = triton.next_power_of_2(head_num)
BLOCK_DMODEL = triton.next_power_of_2(head_dim)
grid = (bs_seq,)
num_warps = 1
_fwd_kernel_destindex_copy_quantize_kv_init_asym[grid](
K, Out, Out_scale_zero,
K.stride(0), K.stride(1), K.stride(2),
Out.stride(0), Out.stride(1), Out.stride(2),
Out_scale_zero.stride(0), Out_scale_zero.stride(1), Out_scale_zero.stride(2),
head_num,head_dim,
BLOCK_DMODEL= BLOCK_DMODEL,
BLOCK_HEAD=BLOCK_HEAD,
num_warps=num_warps,
num_stages=1,
)
return
@triton.jit
def _bwd_kernel_destindex_dequantize_kv(
Quantized_Out, Out_scale_zero, Dequantized_Out,
stride_qo_bs, stride_qo_h, stride_qo_d,
stride_os_bs, stride_os_h, stride_os_d,
stride_do_bs, stride_do_h, stride_do_d,
head_num,head_dim,
BLOCK_DMODEL: tl.constexpr,
BLOCK_HEAD: tl.constexpr
):
cur_index = tl.program_id(0)
offs_h = tl.arange(0, BLOCK_HEAD)
offs_d = tl.arange(0, BLOCK_DMODEL)
scales_dtype = Out_scale_zero.dtype.element_ty
dest_index = cur_index
m1 = offs_h[:, None] < head_num
m2 = offs_d[None,:] < head_dim
mask = m1&m2
# Load quantized data
q_data = tl.load(
Quantized_Out + dest_index * stride_qo_bs + offs_h[:, None] * stride_qo_h + stride_qo_d * offs_d[None, :],
mask=mask,
other=0
)
# Load scale and zero point
data_scale = tl.load(
Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None],
mask=m1,
other=1.0
)
data_zero = tl.load(
Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] + 1,
mask=m1,
other=0
)
# Dequantize
dequantized_data = (q_data.to(tl.int32) + 128 - data_zero.to(tl.int32)).to(scales_dtype) * data_scale
# Store dequantized data
out_ptrs = Dequantized_Out + dest_index * stride_do_bs + stride_do_h * offs_h[:, None] + stride_do_d * offs_d[None, :]
tl.store(out_ptrs, dequantized_data, mask=mask)
@torch.no_grad()
def destindex_dequantize_kv(Quantized_Out, Out_scale_zero, Dequantized_Out):
bs_seq = Quantized_Out.shape[0]
head_num = Quantized_Out.shape[1]
head_dim = Quantized_Out.shape[2]
assert Quantized_Out.shape[1] == Dequantized_Out.shape[1] and Quantized_Out.shape[2] == Dequantized_Out.shape[2]
BLOCK_HEAD = triton.next_power_of_2(head_num)
BLOCK_DMODEL = triton.next_power_of_2(head_dim)
grid = (bs_seq,)
num_warps = 1
_bwd_kernel_destindex_dequantize_kv[grid](
Quantized_Out, Out_scale_zero, Dequantized_Out,
Quantized_Out.stride(0), Quantized_Out.stride(1), Quantized_Out.stride(2),
Out_scale_zero.stride(0), Out_scale_zero.stride(1), Out_scale_zero.stride(2),
Dequantized_Out.stride(0), Dequantized_Out.stride(1), Dequantized_Out.stride(2),
head_num,head_dim,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_HEAD=BLOCK_HEAD,
num_warps=num_warps,
num_stages=1,
)
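A minimal round-trip sketch for the two kernels above (editorial, not part of the patch), assuming a CUDA device with Triton available; the [tokens, heads, head_dim] layout mirrors how q_alltoall calls them with heads == 1.
K = torch.randn(16, 1, 4096, dtype=torch.bfloat16, device="cuda")           # source activations
Q = torch.empty(16, 1, 4096, dtype=torch.int8, device="cuda")               # int8 destination
scale_zero = torch.empty(16, 1, 2, dtype=torch.bfloat16, device="cuda")     # per-row scale and zero point
destindex_copy_quantize_kv_init_asym(K, Q, scale_zero)
D = torch.empty_like(K)
destindex_dequantize_kv(Q, scale_zero, D)
# D approximates K up to the 8-bit asymmetric quantization error.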
@torch.no_grad()
def fp16_to_int8s(fp16_tensor):
fp16_bytes = fp16_tensor.contiguous().view(torch.int8)
int8_high = fp16_bytes[::2]   # high 8 bits
int8_low = fp16_bytes[1::2]   # low 8 bits
return int8_high.unsqueeze(1), int8_low.unsqueeze(1)
@torch.no_grad()
def int8s_to_fp16(int8_high, int8_low):
fp16_bytes = torch.stack([int8_high, int8_low], dim=-1).view(torch.int16)
return fp16_bytes.view(torch.bfloat16)
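A quick consistency check (editorial, not part of the patch) that the two byte-splitting helpers above invert each other for a 1-D bfloat16 tensor:
x = torch.randn(8, dtype=torch.bfloat16)
hi, lo = fp16_to_int8s(x)                         # two [8, 1] int8 tensors, one byte per value each
y = int8s_to_fp16(hi.squeeze(1), lo.squeeze(1))   # [8, 1] bfloat16
assert torch.equal(x, y.view(-1))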
def _alltoall(group, input, output_split_sizes, input_split_sizes):
input = input.contiguous()
if output_split_sizes is None:
# Equal split (all2all)
output = torch.empty_like(input)
else:
# Unequal split (all2all-v)
output = input.new_empty(
size=[sum(output_split_sizes)] + list(input.size()[1:]),
dtype=input.dtype,
device=torch.cuda.current_device(),
)
torch.distributed.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
return output
def q_alltoall(output, input, output_split_sizes, input_split_sizes, group):
t, s = input.shape[0], input.shape[1]
input_buffer_int8 = torch.empty((t, 1, s), dtype=torch.int8, device="cuda")
buffer_scales = torch.empty((t, 1, 2), dtype=torch.bfloat16, device="cuda")
input_q = input.unsqueeze(1)
destindex_copy_quantize_kv_init_asym(
input_q,
input_buffer_int8,
buffer_scales,
)
input_buffer_int8 = input_buffer_int8.squeeze()
buffer_scales = buffer_scales.squeeze()
buffer_scales_h, buffer_scales_l = fp16_to_int8s(buffer_scales[:,0])
buffer_shift_h, buffer_shift_l = fp16_to_int8s(buffer_scales[:,1])
input_all = torch.cat([input_buffer_int8, buffer_scales_h, buffer_scales_l,buffer_shift_h, buffer_shift_l], dim=1)
torch.distributed.all_to_all_single(
output,
input_all,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
scale = int8s_to_fp16(output[:,-4], output[:,-3])
shift = int8s_to_fp16(output[:,-2], output[:,-1])
scales = torch.cat([scale,shift],dim=1).unsqueeze(1)
deq_out = torch.empty((output.shape[0], 1, output.shape[1]-4), dtype=torch.bfloat16, device="cuda")
destindex_dequantize_kv(output[:,:-4].unsqueeze(1), scales, deq_out)
return deq_out.squeeze()
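For reference (editorial note, not part of the patch): each row q_alltoall puts on the wire is the int8-quantized hidden vector followed by four int8 columns carrying the bfloat16 scale and zero point split into bytes, which is why _AllToAll.forward sizes the int8 receive buffer with hidden_size + 4 columns.
# Per-row wire layout produced by q_alltoall (hidden_size = s):
#   [ q_hidden[0] ... q_hidden[s-1] | scale_hi | scale_lo | zero_hi | zero_lo ]
hidden_size = 4096                     # example value
packed_row_width = hidden_size + 4     # matches the "+ 4" in _AllToAll.forward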
...@@ -3,8 +3,8 @@ from typing import Optional, Tuple
import torch
from megatron.training import get_args
from megatron.core.tensor_parallel import (
all_to_all,
gather_from_sequence_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
...@@ -15,6 +15,8 @@ from megatron.core.transformer.moe.moe_utils import (
)
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher
from dcu_megatron.core.tensor_parallel import all_to_all
# decouple per-batch state from MoEAlltoAllTokenDispatcher
class MoEAlltoAllPerBatchState:
...@@ -35,6 +37,13 @@ class MoEAlltoAllPerBatchState:
class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# use_qcomm
args = get_args()
self.use_qcomm = args.use_qcomm
def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
state.num_global_tokens_per_local_expert = getattr(
self, "num_global_tokens_per_local_expert", None
...@@ -125,7 +134,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
"before_ep_alltoall", tokens_per_expert
)
global_input_tokens = all_to_all(
self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
)
return tokens_per_expert, global_input_tokens
...@@ -249,7 +258,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
permutated_local_input_tokens = all_to_all(
self.ep_group, hidden_states, self.input_splits, self.output_splits, use_qcomm=self.use_qcomm
)
return permutated_local_input_tokens
...
...@@ -10,41 +10,13 @@ from megatron.core.utils import (
deprecate_inference_params,
make_viewless_tensor,
)
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables
class TransformerLayer(MegatronCoreTransformerLayer):
def _callable_wrapper(
self, is_forward, func, stream, event, *args, skip_detach=False, **kwargs
):
"""
Wraps a function call so that it waits for a given CUDA event before
proceeding and then runs the function on a specified CUDA stream.
"""
torch.cuda.nvtx.range_push(func.__name__)
event.wait(stream)
with torch.cuda.stream(stream):
outputs = func(*args, **kwargs)
event.record(stream)
if skip_detach:
torch.cuda.nvtx.range_pop()
return outputs
detached_output_tensors = []
if not is_forward:
torch.cuda.nvtx.range_pop()
return outputs, detached_output_tensors
for tensor in outputs:
if tensor is None:
detached_output_tensors.append(None)
elif tensor.dtype.is_floating_point:
detached_output_tensors.append(tensor.detach().requires_grad_(True))
else:
detached_output_tensors.append(tensor.detach())
torch.cuda.nvtx.range_pop()
return outputs, detached_output_tensors
def forward(
self,
hidden_states: Tensor,
...@@ -61,6 +33,23 @@ class TransformerLayer(MegatronCoreTransformerLayer):
*,
inference_params: Optional[Any] = None,
):
if not isinstance(self.mlp, MoELayer):
return super().forward(
hidden_states=hidden_states,
context=context,
context_mask=context_mask,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
inference_params=inference_params,
)
(
hidden_states,
pre_mlp_layernorm_output,
...@@ -123,7 +112,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
residual = hidden_states
# Optional Input Layer norm
if self.recompute_input_layernorm:
self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
self.input_layernorm, hidden_states
)
else:
input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
...@@ -138,6 +133,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
sequence_len_offset=sequence_len_offset,
)
if self.recompute_input_layernorm:
# discard the output of the input layernorm and register the recompute
# as a gradient hook of attention_output_with_bias[0]
self.input_layernorm_checkpoint.discard_output_and_register_recompute(
attention_output_with_bias[0]
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
...@@ -178,7 +180,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
)
# Optional Layer norm post the cross-attention.
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
self.pre_mlp_layernorm, hidden_states
)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
...@@ -249,6 +257,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):
if shared_expert_output is not None:
output += shared_expert_output
mlp_output_with_bias = (output, mlp_bias)
if self.recompute_pre_mlp_layernorm:
# discard the output of the pre-mlp layernorm and register the recompute
# as a gradient hook of mlp_output_with_bias[0]
self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
mlp_output_with_bias[0]
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
...@@ -259,10 +277,11 @@ class TransformerLayer(MegatronCoreTransformerLayer):
return output
def _submodule_attention_dw(self):
self.self_attention.backward_dw()
def _submodule_attention_router_compound_dw(self):
self._submodule_attention_dw()
def _submodule_mlp_dw(self):
self.mlp.backward_dw()
...@@ -23,6 +23,7 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
# add extra arguments
parser = _add_extra_network_size_args(parser)
parser = _add_extra_training_args(parser)
parser = _add_extra_initialization_args(parser)
parser = _add_extra_distributed_args(parser)
parser = _add_extra_tokenizer_args(parser)
parser = _add_extra_moe_args(parser)
...@@ -96,6 +97,14 @@ def _add_extra_training_args(parser):
return parser
def _add_extra_initialization_args(parser):
group = parser.add_argument_group(title='extra initialization args')
group.add_argument('--reproduce', action='store_true',
help='Reproduce the training loss run-to-run; requires --seed > 0.')
return parser
def _add_extra_tokenizer_args(parser):
# remove the original argument
remove_original_params(parser, ["tokenizer_type"])
...@@ -120,6 +129,10 @@ def _add_extra_tokenizer_args(parser):
'NullTokenizer',
'DeepSeekV2Tokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--use-qcomm',
default=False,
action="store_true",
help='Use quantized (int8) all-to-all communication.')
return parser
...
"""Megatron initialization.""" """Megatron initialization."""
import random
import time
import numpy as np
import torch
from datetime import timedelta
from megatron.training import get_args
from megatron.core import mpu, tensor_parallel
def _compile_dependencies():
...@@ -105,7 +108,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
# Call the init process
init_process_group_kwargs = {
'backend': args.distributed_backend,
'world_size': args.world_size,
'rank': args.rank,
'init_method': args.dist_url,
...@@ -149,3 +152,35 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
f"> initialized pipeline model parallel with size "
f"{mpu.get_pipeline_model_parallel_world_size()}"
)
def _set_random_seed(
seed_: int,
data_parallel_random_init: bool = False,
te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
):
"""Set random seed for reproducability."""
args = get_args()
if seed_ is not None and seed_ > 0:
# Ensure that different pipeline MP stages get different seeds.
seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
# Ensure different data parallel ranks get different seeds
if data_parallel_random_init:
seed = seed + (10 * mpu.get_data_parallel_rank())
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
tensor_parallel.model_parallel_cuda_manual_seed(
seed, te_rng_tracker, inference_rng_tracker, use_cudagraphable_rng
)
if args.reproduce:
assert args.attention_dropout == 0, f"--reproduce requires attention_dropout == 0, got {args.attention_dropout}."
assert args.hidden_dropout == 0, f"--reproduce requires hidden_dropout == 0, got {args.hidden_dropout}."
torch.backends.cudnn.deterministic = True  # use deterministic cuDNN kernels
torch.backends.cudnn.benchmark = False  # do not auto-tune convolution algorithms
torch.use_deterministic_algorithms(True)  # force deterministic torch ops to avoid nondeterminism
else:
raise ValueError("Seed ({}) should be a positive integer.".format(seed_))
...@@ -9,8 +9,10 @@ from megatron.training.tokenizer.tokenizer import (
_Llama2Tokenizer,
CustomTikTokenizer,
_NullTokenizer,
_NullMultimodalTokenizer,
_vocab_size_with_padding
)
from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer
def build_tokenizer(args, **kwargs):
...@@ -92,7 +94,11 @@ def build_tokenizer(args, **kwargs):
args.tokenizer_prompt_format,
args.special_tokens,
args.image_tag_type,
args.force_system_message,
)
elif args.tokenizer_type == 'NullMultimodalTokenizer':
assert args.vocab_size is not None
tokenizer = _NullMultimodalTokenizer(args.vocab_size)
elif args.tokenizer_type == "DeepSeekV2Tokenizer":
tokenizer = _DeepSeekV2Tokenizer(args.tokenizer_model, args.extra_vocab_size)
args.padded_vocab_size = tokenizer.vocab_size
...
...@@ -53,18 +53,9 @@ from megatron.training.training import (
stimer = StragglerDetector()
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
"""Training function: run train_step desired number of times, run validation, checkpoint."""
args = get_args()
timers = get_timers()
...@@ -74,10 +65,7 @@ def train(
try:
from workload_inspector.utils.webserver import run_server
import threading
threading.Thread(target=run_server, daemon=True, args=(torch.distributed.get_rank(), )).start()
except ModuleNotFoundError:
print_rank_0("workload inspector module not found.")
...@@ -100,17 +88,11 @@ def train(
rerun_state_machine.current_iteration = iteration
# Track E2E metrics at the start of training.
one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples, seq_length=args.seq_length,
train_iters=args.train_iters, save=args.save, async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)
num_floating_point_operations_so_far = args.num_floating_point_operations_so_far
...@@ -118,10 +100,9 @@ def train(
config.grad_scale_func = optimizer.scale_loss
config.timers = timers
if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
assert config.no_sync_func is None, \
('When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce')
config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
if len(model) == 1:
config.no_sync_func = config.no_sync_func[0]
...@@ -145,9 +126,8 @@ def train(
if args.manual_gc:
# Disable the default garbage collector and perform the collection manually.
# This is to align the timing of garbage collection across ranks.
assert args.manual_gc_interval >= 0, \
'Manual garbage collection interval should be larger than or equal to 0'
gc.disable()
gc.collect()
...@@ -157,13 +137,10 @@ def train(
world = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
mmcnt = args.straggler_minmax_count
stimer.configure(world, rank,
mmcnt = mmcnt,
enabled = not args.disable_straggler_on_startup,
port = args.straggler_ctrlr_port)
num_floating_point_operations_since_last_log_event = 0.0
num_microbatches = get_num_microbatches()
...@@ -171,10 +148,10 @@ def train(
eval_iterations = 0
def get_e2e_base_metrics():
"""Get base metrics values for one-logger to calculate E2E tracking metrics.
"""
num_floating_point_operations_since_current_train_start = \
num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
return {
'iteration': iteration,
'train_duration': timers('interval-time').active_time(),
...@@ -184,7 +161,7 @@ def train(
'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
'consumed_train_samples': args.consumed_train_samples,
'world_size': args.world_size,
'seq_length': args.seq_length
}
# Cache into one-logger for callback.
if one_logger:
...@@ -192,11 +169,7 @@ def train(
one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)
prof = None
if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
def trace_handler(p):
from pathlib import Path
Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
...@@ -242,9 +215,8 @@ def train(
pre_hook_enabled = False
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
...@@ -270,20 +242,14 @@ def train(
# to make sure training configuration is still valid.
update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, \
(f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}")
if args.save is not None:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
...@@ -292,9 +258,9 @@ def train(
# Dummy train_step to fast forward train_data_iterator.
dummy_train_step(train_data_iterator)
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
args.skipped_train_samples += batch_size
continue
...@@ -302,28 +268,19 @@ def train(
# Run training step.
args.curr_iteration = iteration
ft_integration.on_training_step_start()
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
ft_integration.on_training_step_end()
if should_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
if should_exit:
break
...@@ -346,13 +303,12 @@ def train(
pre_hook_enabled = True
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
num_skipped_samples_in_batch = (get_current_global_batch_size() -
get_current_running_global_batch_size())
if args.decrease_batch_size_if_needed:
assert num_skipped_samples_in_batch >= 0
else:
...@@ -378,22 +334,16 @@ def train(
decoupled_learning_rate = param_group['lr']
else:
learning_rate = param_group['lr']
report_memory_flag = training_log(loss_dict, total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
timers('interval-time').stop()
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
...@@ -403,18 +353,11 @@ def train(
gc.collect()
prefix = f'iteration {iteration}'
timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
eval_duration += timers('eval-time').elapsed()
eval_iterations += args.eval_iters
timers('eval-time').stop()
...@@ -430,25 +373,13 @@ def train(
# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
num_floating_point_operations_since_last_log_event)
# Checkpoint and decide whether to exit.
should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
if should_exit:
break
...@@ -477,7 +408,6 @@ def train(
if wandb_writer:
wandb_writer.finish()
ft_integration.shutdown()
one_logger_utils.finish()
sys.exit(exit_code)
return iteration, num_floating_point_operations_so_far