Commit c63fceee authored by dongcl

Merge branch 'a2a_overlap' into 'core_v0.12.0'

A2a overlap

See merge request OpenDAS/dcu_megatron!4
parents 6c3cfb1d bfe0b4a9
......@@ -169,6 +169,15 @@ class CoreAdaptation(MegatronAdaptationABC):
staticmethod,
apply_wrapper=True)
# reduce_scatter_to_sequence_parallel_region
MegatronAdaptation.register('megatron.core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region',
torch._dynamo.disable,
apply_wrapper=True)
# reduce_from_tensor_model_parallel_region
MegatronAdaptation.register('megatron.core.tensor_parallel.mappings.reduce_from_tensor_model_parallel_region',
torch._dynamo.disable,
apply_wrapper=True)
# flux
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
from ..core.tensor_parallel.layers import (
......@@ -189,6 +198,7 @@ class CoreAdaptation(MegatronAdaptationABC):
from ..training.initialize import _initialize_distributed
from ..training.initialize import _compile_dependencies
from ..training.training import train
from ..training.initialize import _set_random_seed
MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
build_tokenizer)
......@@ -199,6 +209,10 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
_compile_dependencies)
# add a fixed random seed
MegatronAdaptation.register('megatron.training.initialize._set_random_seed',
_set_random_seed)
# add trace_handler
MegatronAdaptation.register('megatron.training.training.train',
train)
......
......@@ -397,6 +397,10 @@ class DenseAttnNode(TransformerLayerNode):
)
return hidden_states
def dw(self):
with torch.cuda.nvtx.range(f"{self.name} wgrad"):
self.layer._submodule_attention_dw()
class FakeScheduleNode:
......@@ -411,6 +415,10 @@ class DenseMlpNode(TransformerLayerNode):
def forward_impl(self, hidden_states):
return self.layer._submodule_dense_forward(hidden_states)
def dw(self):
with torch.cuda.nvtx.range(f"{self.name} wgrad"):
self.layer._submodule_mlp_dw()
def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream):
common_state = TransformerLayerState()
......@@ -418,6 +426,7 @@ def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream)
attn.name = "attn"
dispatch = FakeScheduleNode()
mlp = DenseMlpNode(chunk_state, common_state, layer, comp_stream, event)
mlp.name = "mlp"
combine = FakeScheduleNode()
return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
......
......@@ -7,6 +7,7 @@ from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.pipeline_parallel.schedules import set_current_microbatch
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.utils import (
get_attr_wrapped_model,
......@@ -28,19 +29,6 @@ from megatron.core.pipeline_parallel.schedules import (
from .combined_1f1b import VppContextManager, forward_backward_step, set_streams, wrap_forward_func
def set_current_microbatch(model, microbatch_id):
"""Set the current microbatch."""
decoder_exists = True
decoder = None
try:
decoder = get_attr_wrapped_model(model, "decoder")
except RuntimeError:
decoder_exists = False
if decoder_exists and decoder is not None:
for layer in decoder.layers:
layer.current_microbatch = microbatch_id
def get_pp_rank_microbatches(
num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage, forward_only=False
):
......
from .mappings import all_to_all
\ No newline at end of file
import torch
from .qcomm import q_alltoall
class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_qcomm=False):
"""Forward function."""
ctx.group = group
ctx.output_split_sizes = output_split_sizes
ctx.input_split_sizes = input_split_sizes
ctx.use_qcomm = use_qcomm
world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input
input = input.contiguous()
if output_split_sizes is None:
# Equal split (all2all)
if use_qcomm:
output = input.new_empty(
size=[input.shape[0], input.shape[1] + 4],
dtype=torch.int8,
device=torch.cuda.current_device(),
)
else:
output = torch.empty_like(input)
else:
# Unequal split (all2all-v)
if use_qcomm:
# the quantized payload carries 4 extra metadata bytes (scale and zero-point) per row
output = input.new_empty(
size=[sum(output_split_sizes), input.shape[1] + 4],
dtype=torch.int8,
device=torch.cuda.current_device(),
)
else:
output = input.new_empty(
size=[sum(output_split_sizes)] + list(input.size()[1:]),
dtype=input.dtype,
device=torch.cuda.current_device(),
)
if use_qcomm:
output = q_alltoall(output, input, output_split_sizes, input_split_sizes, group)
else:
torch.distributed.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
return output
@staticmethod
def backward(ctx, *grad_output):
"""Backward function."""
return (
None,
_AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.use_qcomm),
None,
None,
None,
)
def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None, use_qcomm=False):
"""Wrapper for autograd function"""
return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_qcomm)
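# Illustrative usage sketch (an assumption, not part of the diff): exchanging an uneven
# number of token rows across an expert-parallel group through the autograd-aware wrapper
# above. `ep_group`, `tokens`, and the split lists are placeholders supplied by the caller;
# with use_qcomm=True the payload is int8-quantized before all_to_all_single and
# dequantized back to bfloat16 on the receiving side.
def exchange_tokens(ep_group, tokens, output_splits, input_splits, use_qcomm=False):
    # tokens: [num_local_tokens, hidden]; *_splits are per-rank row counts (None = equal split)
    return all_to_all(ep_group, tokens, output_splits, input_splits, use_qcomm=use_qcomm)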
import torch
import triton
import triton.language as tl
import random
import unittest
import json
import os
import time
@triton.jit
def _fwd_kernel_destindex_copy_quantize_kv_init_asym(
K, Out, Out_scale_zero,
stride_k_bs, stride_k_h, stride_k_d,
stride_o_bs, stride_o_h, stride_o_d,
stride_os_bs, stride_os_h, stride_os_d,
head_num, head_dim,
BLOCK_DMODEL: tl.constexpr,
BLOCK_HEAD: tl.constexpr
):
cur_index = tl.program_id(0)
offs_h = tl.arange(0, BLOCK_HEAD)
offs_d = tl.arange(0, BLOCK_DMODEL)
dest_index = cur_index
m1 = offs_h[:, None] < head_num
m2 = offs_d[None, :] < head_dim
mask = m1 & m2
src_data = tl.load(K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],
mask=mask, other=0.0).to(tl.float32)
src_data_max = tl.max(src_data, axis=1, keep_dims=True)
src_data_min = tl.min(src_data, axis=1, keep_dims=True)
data_scale = (src_data_max - src_data_min) / 255.0
data_zero = (-1 * src_data_min / data_scale).to(tl.int32)
q_src_data = (tl.clamp((src_data / data_scale).to(tl.int32).to(tl.float32) + data_zero.to(tl.float32), 0.0, 255.0).to(tl.int32) - 128).to(tl.int8)
data_scale = data_scale.to(Out_scale_zero.dtype.element_ty)
data_zero = data_zero.to(Out_scale_zero.dtype.element_ty)
o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
os_ptrs = Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]
oz_ptrs = Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] + 1
tl.store(o_ptrs, q_src_data, mask=mask)
tl.store(os_ptrs, data_scale, mask=m1)
tl.store(oz_ptrs, data_zero, mask=m1)
@torch.no_grad()
def destindex_copy_quantize_kv_init_asym(K, Out, Out_scale_zero):
bs_seq = K.shape[0]
head_num = K.shape[1]
head_dim = K.shape[2]
assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]
BLOCK_HEAD = triton.next_power_of_2(head_num)
BLOCK_DMODEL = triton.next_power_of_2(head_dim)
grid = (bs_seq,)
num_warps = 1
_fwd_kernel_destindex_copy_quantize_kv_init_asym[grid](
K, Out, Out_scale_zero,
K.stride(0), K.stride(1), K.stride(2),
Out.stride(0), Out.stride(1), Out.stride(2),
Out_scale_zero.stride(0), Out_scale_zero.stride(1), Out_scale_zero.stride(2),
head_num, head_dim,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_HEAD=BLOCK_HEAD,
num_warps=num_warps,
num_stages=1,
)
return
@triton.jit
def _bwd_kernel_destindex_dequantize_kv(
Quantized_Out, Out_scale_zero, Dequantized_Out,
stride_qo_bs, stride_qo_h, stride_qo_d,
stride_os_bs, stride_os_h, stride_os_d,
stride_do_bs, stride_do_h, stride_do_d,
head_num, head_dim,
BLOCK_DMODEL: tl.constexpr,
BLOCK_HEAD: tl.constexpr
):
cur_index = tl.program_id(0)
offs_h = tl.arange(0, BLOCK_HEAD)
offs_d = tl.arange(0, BLOCK_DMODEL)
scales_dtype = Out_scale_zero.dtype.element_ty
dest_index = cur_index
m1 = offs_h[:, None] < head_num
m2 = offs_d[None, :] < head_dim
mask = m1 & m2
# Load quantized data
q_data = tl.load(
Quantized_Out + dest_index * stride_qo_bs + offs_h[:, None] * stride_qo_h + stride_qo_d * offs_d[None, :],
mask=mask,
other=0
)
# Load scale and zero point
data_scale = tl.load(
Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None],
mask=m1,
other=1.0
)
data_zero = tl.load(
Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] + 1,
mask=m1,
other=0
)
# Dequantize
dequantized_data = (q_data.to(tl.int32) + 128 - data_zero.to(tl.int32)).to(scales_dtype) * data_scale
# Store dequantized data
out_ptrs = Dequantized_Out + dest_index * stride_do_bs + stride_do_h * offs_h[:, None] + stride_do_d * offs_d[None, :]
tl.store(out_ptrs, dequantized_data, mask=mask)
@torch.no_grad()
def destindex_dequantize_kv(Quantized_Out, Out_scale_zero, Dequantized_Out):
bs_seq = Quantized_Out.shape[0]
head_num = Quantized_Out.shape[1]
head_dim = Quantized_Out.shape[2]
assert Quantized_Out.shape[1] == Dequantized_Out.shape[1] and Quantized_Out.shape[2] == Dequantized_Out.shape[2]
BLOCK_HEAD = triton.next_power_of_2(head_num)
BLOCK_DMODEL = triton.next_power_of_2(head_dim)
grid = (bs_seq,)
num_warps = 1
_bwd_kernel_destindex_dequantize_kv[grid](
Quantized_Out, Out_scale_zero, Dequantized_Out,
Quantized_Out.stride(0), Quantized_Out.stride(1), Quantized_Out.stride(2),
Out_scale_zero.stride(0), Out_scale_zero.stride(1), Out_scale_zero.stride(2),
Dequantized_Out.stride(0), Dequantized_Out.stride(1), Dequantized_Out.stride(2),
head_num, head_dim,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_HEAD=BLOCK_HEAD,
num_warps=num_warps,
num_stages=1,
)
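# Minimal single-GPU round-trip check (an assumption, not part of the diff): quantize a
# [tokens, heads, head_dim] bfloat16 tensor with the asymmetric per-row int8 kernel above,
# dequantize it, and verify the error stays within roughly one quantization step plus
# bfloat16 rounding. Requires a CUDA device with Triton available.
def _check_quant_roundtrip(tokens=16, heads=1, head_dim=128):
    x = torch.randn(tokens, heads, head_dim, dtype=torch.bfloat16, device="cuda")
    q = torch.empty_like(x, dtype=torch.int8)
    scale_zero = torch.empty(tokens, heads, 2, dtype=torch.bfloat16, device="cuda")
    destindex_copy_quantize_kv_init_asym(x, q, scale_zero)
    y = torch.empty_like(x)
    destindex_dequantize_kv(q, scale_zero, y)
    # each row is mapped onto 256 levels of its own [min, max] range
    step = (x.float().amax(dim=-1, keepdim=True) - x.float().amin(dim=-1, keepdim=True)) / 255.0
    assert ((x.float() - y.float()).abs() <= 2 * step + 1e-2).all()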
@torch.no_grad()
def fp16_to_int8s(fp16_tensor):
fp16_bytes = fp16_tensor.contiguous().view(torch.int8)
int8_high = fp16_bytes[::2]  # high 8 bits
int8_low = fp16_bytes[1::2]  # low 8 bits
return int8_high.unsqueeze(1), int8_low.unsqueeze(1)
@torch.no_grad()
def int8s_to_fp16(int8_high, int8_low):
fp16_bytes = torch.stack([int8_high, int8_low], dim=-1).view(torch.int16)
return fp16_bytes.view(torch.bfloat16)
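# Round-trip sanity check (an assumption, not part of the diff): splitting a bfloat16
# vector into its two byte planes and reassembling it is lossless, since the two helpers
# above only reinterpret memory. fp16_to_int8s returns two [n, 1] int8 tensors;
# int8s_to_fp16 takes the matching 1-D byte columns and yields an [n, 1] bfloat16 tensor.
def _check_byte_split_roundtrip(n=8):
    x = torch.randn(n, dtype=torch.bfloat16)
    hi, lo = fp16_to_int8s(x)
    y = int8s_to_fp16(hi.squeeze(1), lo.squeeze(1)).view(-1)
    assert torch.equal(x, y)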
def _alltoall(group, input, output_split_sizes, input_split_sizes):
input = input.contiguous()
if output_split_sizes is None:
# Equal split (all2all)
output = torch.empty_like(input)
else:
# Unequal split (all2all-v)
output = input.new_empty(
size=[sum(output_split_sizes)] + list(input.size()[1:]),
dtype=input.dtype,
device=torch.cuda.current_device(),
)
torch.distributed.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
return output
def q_alltoall(output, input, output_split_sizes, input_split_sizes, group):
t, s = input.shape[0], input.shape[1]
input_buffer_int8 = torch.empty((t, 1, s), dtype=torch.int8, device="cuda")
buffer_scales = torch.empty((t, 1, 2), dtype=torch.bfloat16, device="cuda")
input_q = input.unsqueeze(1)
destindex_copy_quantize_kv_init_asym(
input_q,
input_buffer_int8,
buffer_scales,
)
input_buffer_int8 = input_buffer_int8.squeeze()
buffer_scales = buffer_scales.squeeze()
buffer_scales_h, buffer_scales_l = fp16_to_int8s(buffer_scales[:, 0])
buffer_shift_h, buffer_shift_l = fp16_to_int8s(buffer_scales[:, 1])
input_all = torch.cat([input_buffer_int8, buffer_scales_h, buffer_scales_l, buffer_shift_h, buffer_shift_l], dim=1)
torch.distributed.all_to_all_single(
output,
input_all,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
scale = int8s_to_fp16(output[:, -4], output[:, -3])
shift = int8s_to_fp16(output[:, -2], output[:, -1])
scales = torch.cat([scale, shift], dim=1).unsqueeze(1)
deq_out = torch.empty((output.shape[0], 1, output.shape[1] - 4), dtype=torch.bfloat16, device="cuda")
destindex_dequantize_kv(output[:, :-4].unsqueeze(1), scales, deq_out)
return deq_out.squeeze()
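# Receive-buffer sketch (an assumption, not part of the diff): per token row, q_alltoall
# sends the int8 payload followed by 4 metadata bytes (the bfloat16 scale and zero-point,
# each split into two int8 bytes), which is why _AllToAll allocates the receive buffer
# with hidden + 4 int8 columns when use_qcomm is enabled.
def make_qcomm_recv_buffer(num_recv_tokens, hidden):
    return torch.empty(num_recv_tokens, hidden + 4, dtype=torch.int8, device="cuda")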
......@@ -3,8 +3,8 @@ from typing import Optional, Tuple
import torch
from megatron.training import get_args
from megatron.core.tensor_parallel import (
all_to_all,
gather_from_sequence_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
......@@ -15,6 +15,8 @@ from megatron.core.transformer.moe.moe_utils import (
)
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher
from dcu_megatron.core.tensor_parallel import all_to_all
# decouple per-batch state from MoEAlltoAllTokenDispatcher
class MoEAlltoAllPerBatchState:
......@@ -35,6 +37,13 @@ class MoEAlltoAllPerBatchState:
class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# use_qcomm
args = get_args()
self.use_qcomm = args.use_qcomm
def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
state.num_global_tokens_per_local_expert = getattr(
self, "num_global_tokens_per_local_expert", None
......@@ -125,7 +134,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
"before_ep_alltoall", tokens_per_expert
)
global_input_tokens = all_to_all(
self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
)
return tokens_per_expert, global_input_tokens
......@@ -249,7 +258,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
permutated_local_input_tokens = all_to_all(
self.ep_group, hidden_states, self.input_splits, self.output_splits
self.ep_group, hidden_states, self.input_splits, self.output_splits, use_qcomm=self.use_qcomm
)
return permutated_local_input_tokens
......
......@@ -10,41 +10,13 @@ from megatron.core.utils import (
deprecate_inference_params,
make_viewless_tensor,
)
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables
class TransformerLayer(MegatronCoreTransformerLayer):
def _callable_wrapper(
self, is_forward, func, stream, event, *args, skip_detach=False, **kwargs
):
"""
Wraps a function call so that it waits for a given CUDA event before
proceeding and then runs the function on a specified CUDA stream.
"""
torch.cuda.nvtx.range_push(func.__name__)
event.wait(stream)
with torch.cuda.stream(stream):
outputs = func(*args, **kwargs)
event.record(stream)
if skip_detach:
torch.cuda.nvtx.range_pop()
return outputs
detached_output_tensors = []
if not is_forward:
torch.cuda.nvtx.range_pop()
return outputs, detached_output_tensors
for tensor in outputs:
if tensor is None:
detached_output_tensors.append(None)
elif tensor.dtype.is_floating_point:
detached_output_tensors.append(tensor.detach().requires_grad_(True))
else:
detached_output_tensors.append(tensor.detach())
torch.cuda.nvtx.range_pop()
return outputs, detached_output_tensors
def forward(
self,
hidden_states: Tensor,
......@@ -61,6 +33,23 @@ class TransformerLayer(MegatronCoreTransformerLayer):
*,
inference_params: Optional[Any] = None,
):
if not isinstance(self.mlp, MoELayer):
return super().forward(
hidden_states=hidden_states,
context=context,
context_mask=context_mask,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
inference_params=inference_params,
)
(
hidden_states,
pre_mlp_layernorm_output,
......@@ -123,7 +112,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
residual = hidden_states
# Optional Input Layer norm
input_layernorm_output = self.input_layernorm(hidden_states)
if self.recompute_input_layernorm:
self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
self.input_layernorm, hidden_states
)
else:
input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
......@@ -138,6 +133,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
sequence_len_offset=sequence_len_offset,
)
if self.recompute_input_layernorm:
# discard the output of the input layernorm and register the recompute
# as a gradient hook of attention_output_with_bias[0]
self.input_layernorm_checkpoint.discard_output_and_register_recompute(
attention_output_with_bias[0]
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
......@@ -178,7 +180,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
)
# Optional Layer norm post the cross-attention.
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
self.pre_mlp_layernorm, hidden_states
)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
......@@ -249,6 +257,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):
if shared_expert_output is not None:
output += shared_expert_output
mlp_output_with_bias = (output, mlp_bias)
if self.recompute_pre_mlp_layernorm:
# discard the output of the pre-mlp layernorm and register the recompute
# as a gradient hook of mlp_output_with_bias[0]
self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
mlp_output_with_bias[0]
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
......@@ -259,10 +277,11 @@ class TransformerLayer(MegatronCoreTransformerLayer):
return output
def _submodule_attention_router_compound_dw(self):
def _submodule_attention_dw(self):
self.self_attention.backward_dw()
# raise NotImplementedError("Not implemented")
def _submodule_attention_router_compound_dw(self):
self._submodule_attention_dw()
def _submodule_mlp_dw(self):
self.mlp.backward_dw()
# raise NotImplementedError("Not implemented")
......@@ -23,6 +23,7 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
# add extra arguments
parser = _add_extra_network_size_args(parser)
parser = _add_extra_training_args(parser)
parser = _add_extra_initialization_args(parser)
parser = _add_extra_distributed_args(parser)
parser = _add_extra_tokenizer_args(parser)
parser = _add_extra_moe_args(parser)
......@@ -96,6 +97,14 @@ def _add_extra_training_args(parser):
return parser
def _add_extra_initialization_args(parser):
group = parser.add_argument_group(title='extra initialization args')
group.add_argument('--reproduce', action='store_true',
help='Reproduce the training loss exactly across runs; requires --seed > 0.')
return parser
def _add_extra_tokenizer_args(parser):
# remove the original parameter
remove_original_params(parser, ["tokenizer_type"])
......@@ -120,6 +129,10 @@ def _add_extra_tokenizer_args(parser):
'NullTokenizer',
'DeepSeekV2Tokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--use-qcomm',
default=False,
action="store_true",
help='use quantized communication')
return parser
......
"""Megatron initialization."""
import random
import time
import numpy as np
import torch
from datetime import timedelta
from megatron.training import get_args
from megatron.core import mpu
from megatron.core import mpu, tensor_parallel
def _compile_dependencies():
......@@ -105,7 +108,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
# Call the init process
init_process_group_kwargs = {
'backend' : args.distributed_backend,
'backend': args.distributed_backend,
'world_size': args.world_size,
'rank': args.rank,
'init_method': args.dist_url,
......@@ -149,3 +152,35 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
f"> initialized pipeline model parallel with size "
f"{mpu.get_pipeline_model_parallel_world_size()}"
)
def _set_random_seed(
seed_: int,
data_parallel_random_init: bool = False,
te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
):
"""Set random seed for reproducability."""
args = get_args()
if seed_ is not None and seed_ > 0:
# Ensure that different pipeline MP stages get different seeds.
seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
# Ensure different data parallel ranks get different seeds
if data_parallel_random_init:
seed = seed + (10 * mpu.get_data_parallel_rank())
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
tensor_parallel.model_parallel_cuda_manual_seed(
seed, te_rng_tracker, inference_rng_tracker, use_cudagraphable_rng
)
if args.reproduce:
assert args.attention_dropout == 0, f"To use --reproduce, args.attention_dropout = {args.attention_dropout} must be set to 0."
assert args.hidden_dropout == 0, f"To use --reproduce, args.hidden_dropout = {args.hidden_dropout} must be set to 0."
torch.backends.cudnn.deterministic = True  # make the cuDNN backend choose deterministic algorithms
torch.backends.cudnn.benchmark = False  # pin the convolution algorithm selection
torch.use_deterministic_algorithms(True)  # use deterministic torch kernels to avoid nondeterminism
else:
raise ValueError("Seed ({}) should be a positive integer.".format(seed_))
......@@ -9,8 +9,10 @@ from megatron.training.tokenizer.tokenizer import (
_Llama2Tokenizer,
CustomTikTokenizer,
_NullTokenizer,
_NullMultimodalTokenizer,
_vocab_size_with_padding
)
from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer
def build_tokenizer(args, **kwargs):
......@@ -92,7 +94,11 @@ def build_tokenizer(args, **kwargs):
args.tokenizer_prompt_format,
args.special_tokens,
args.image_tag_type,
args.force_system_message,
)
elif args.tokenizer_type == 'NullMultimodalTokenizer':
assert args.vocab_size is not None
tokenizer = _NullMultimodalTokenizer(args.vocab_size)
elif args.tokenizer_type == "DeepSeekV2Tokenizer":
tokenizer = _DeepSeekV2Tokenizer(args.tokenizer_model, args.extra_vocab_size)
args.padded_vocab_size = tokenizer.vocab_size
......
......@@ -53,18 +53,9 @@ from megatron.training.training import (
stimer = StragglerDetector()
def train(
forward_step_func,
model,
optimizer,
opt_param_scheduler,
train_data_iterator,
valid_data_iterator,
process_non_loss_data_func,
config,
checkpointing_context,
non_loss_data_func,
):
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
"""Training function: run train_step desired number of times, run validation, checkpoint."""
args = get_args()
timers = get_timers()
......@@ -74,10 +65,7 @@ def train(
try:
from workload_inspector.utils.webserver import run_server
import threading
threading.Thread(
target=run_server, daemon=True, args=(torch.distributed.get_rank(),)
).start()
threading.Thread(target=run_server, daemon=True, args=(torch.distributed.get_rank(), )).start()
except ModuleNotFoundError:
print_rank_0("workload inspector module not found.")
......@@ -100,17 +88,11 @@ def train(
rerun_state_machine.current_iteration = iteration
# Track E2E metrics at the start of training.
one_logger_utils.on_train_start(
iteration=iteration,
consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples,
seq_length=args.seq_length,
train_iters=args.train_iters,
save=args.save,
async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far,
)
one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples, seq_length=args.seq_length,
train_iters=args.train_iters, save=args.save, async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)
num_floating_point_operations_so_far = args.num_floating_point_operations_so_far
......@@ -118,10 +100,9 @@ def train(
config.grad_scale_func = optimizer.scale_loss
config.timers = timers
if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
assert config.no_sync_func is None, (
'When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce'
)
assert config.no_sync_func is None, \
('When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce')
config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
if len(model) == 1:
config.no_sync_func = config.no_sync_func[0]
......@@ -145,9 +126,8 @@ def train(
if args.manual_gc:
# Disable the default garbage collector and perform the collection manually.
# This is to align the timing of garbage collection across ranks.
assert (
args.manual_gc_interval >= 0
), 'Manual garbage collection interval should be larger than or equal to 0'
assert args.manual_gc_interval >= 0, \
'Manual garbage collection interval should be larger than or equal to 0'
gc.disable()
gc.collect()
......@@ -157,13 +137,10 @@ def train(
world = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
mmcnt = args.straggler_minmax_count
stimer.configure(
world,
rank,
mmcnt=mmcnt,
enabled=not args.disable_straggler_on_startup,
port=args.straggler_ctrlr_port,
)
stimer.configure(world, rank,
mmcnt = mmcnt,
enabled = not args.disable_straggler_on_startup,
port = args.straggler_ctrlr_port)
num_floating_point_operations_since_last_log_event = 0.0
num_microbatches = get_num_microbatches()
......@@ -171,10 +148,10 @@ def train(
eval_iterations = 0
def get_e2e_base_metrics():
"""Get base metrics values for one-logger to calculate E2E tracking metrics."""
num_floating_point_operations_since_current_train_start = (
"""Get base metrics values for one-logger to calculate E2E tracking metrics.
"""
num_floating_point_operations_since_current_train_start = \
num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
)
return {
'iteration': iteration,
'train_duration': timers('interval-time').active_time(),
......@@ -184,7 +161,7 @@ def train(
'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
'consumed_train_samples': args.consumed_train_samples,
'world_size': args.world_size,
'seq_length': args.seq_length,
'seq_length': args.seq_length
}
# Cache into one-logger for callback.
if one_logger:
......@@ -192,11 +169,7 @@ def train(
one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)
prof = None
if (
args.profile
and torch.distributed.get_rank() in args.profile_ranks
and args.use_pytorch_profiler
):
if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
def trace_handler(p):
from pathlib import Path
Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
......@@ -242,9 +215,8 @@ def train(
pre_hook_enabled = False
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(
model, cross_check=True
), "Parameter hashes not matching across DP replicas"
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
......@@ -270,20 +242,14 @@ def train(
# to make sure training configuration is still valid.
update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, (
f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}"
)
assert get_num_microbatches() > num_microbatches, \
(f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}")
if args.save is not None:
save_checkpoint_and_time(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator=train_data_iterator,
)
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
......@@ -292,9 +258,9 @@ def train(
# Dummy train_step to fast forward train_data_iterator.
dummy_train_step(train_data_iterator)
iteration += 1
batch_size = (
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
)
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
args.skipped_train_samples += batch_size
continue
......@@ -302,28 +268,19 @@ def train(
# Run training step.
args.curr_iteration = iteration
ft_integration.on_training_step_start()
(
loss_dict,
skipped_iter,
should_checkpoint,
should_exit,
exit_code,
grad_norm,
num_zeros_in_grad,
) = train_step(
forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config
)
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
ft_integration.on_training_step_end()
if should_checkpoint:
save_checkpoint_and_time(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator=train_data_iterator,
)
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
if should_exit:
break
......@@ -346,13 +303,12 @@ def train(
pre_hook_enabled = True
iteration += 1
batch_size = (
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
)
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
num_skipped_samples_in_batch = (
get_current_global_batch_size() - get_current_running_global_batch_size()
)
num_skipped_samples_in_batch = (get_current_global_batch_size() -
get_current_running_global_batch_size())
if args.decrease_batch_size_if_needed:
assert num_skipped_samples_in_batch >= 0
else:
......@@ -378,22 +334,16 @@ def train(
decoupled_learning_rate = param_group['lr']
else:
learning_rate = param_group['lr']
report_memory_flag = training_log(
loss_dict,
total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration,
loss_scale,
report_memory_flag,
skipped_iter,
grad_norm,
params_norm,
num_zeros_in_grad,
)
report_memory_flag = training_log(loss_dict, total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
timers('interval-time').stop()
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
......@@ -403,18 +353,11 @@ def train(
gc.collect()
prefix = f'iteration {iteration}'
timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(
prefix,
forward_step_func,
valid_data_iterator,
model,
iteration,
process_non_loss_data_func,
config,
verbose=False,
write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func,
)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
eval_duration += timers('eval-time').elapsed()
eval_iterations += args.eval_iters
timers('eval-time').stop()
......@@ -430,25 +373,13 @@ def train(
# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(
model,
optimizer,
opt_param_scheduler,
iteration,
prof,
num_floating_point_operations_since_last_log_event,
)
post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
num_floating_point_operations_since_last_log_event)
# Checkpoint and decide whether to exit.
should_exit = checkpoint_and_decide_exit(
model,
optimizer,
opt_param_scheduler,
iteration,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator,
)
should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
if should_exit:
break
......@@ -477,7 +408,6 @@ def train(
if wandb_writer:
wandb_writer.finish()
ft_integration.shutdown()
one_logger_utils.finish()
sys.exit(exit_code)
return iteration, num_floating_point_operations_so_far