evt_fugx1 / dcu_megatron / Commits / 4e2de453

Commit 4e2de453 authored Mar 26, 2025 by dongcl

    megatron patch

parent d77d95c5

Showing 7 changed files with 1044 additions and 26 deletions (+1044 -26):

    dcu_megatron/adaptor/megatron_adaptor.py         +51   -20
    dcu_megatron/adaptor/patch_utils.py               +7    -4
    dcu_megatron/legacy/model/rms_norm.py             +0    -0
    dcu_megatron/legacy/model/transformer.py         +264    -0
    dcu_megatron/training/arguments.py               +251    -1
    dcu_megatron/training/tokenizer/tokenizer.py      +96    -1
    dcu_megatron/training/training.py                +375    -0
dcu_megatron/adaptor/megatron_adaptor.py  (View file @ 4e2de453)

+# coding=utf-8
+import os
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 import abc
 import sys
 import types
...
@@ -38,15 +24,15 @@ class MegatronAdaptation:
     # MegatronAdaptation.post_execute()

     @classmethod
-    def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False):
+    def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False, apply_wrapper=False):
         """
         Register adaptations into collection.
         """
         if orig_func_name not in cls._patch_info_collection:
             from .patch_utils import Patch
-            cls._patch_info_collection[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)
+            cls._patch_info_collection[orig_func_name] = Patch(orig_func_name, new_func, create_dummy,
+                                                               apply_wrapper=apply_wrapper)
         else:
-            cls._patch_info_collection.get(orig_func_name).set_patch_func(new_func, force_patch)
+            cls._patch_info_collection.get(orig_func_name).set_patch_func(new_func, force_patch,
+                                                                          apply_wrapper=apply_wrapper)

     @classmethod
     def apply(cls):
...
@@ -138,24 +124,50 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.transformer_config.MLATransformerConfig',
                                     MLATransformerConfig)

+        # Moe
+        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
+                                    torch.compile(options={"triton.cudagraphs": True,
+                                                           "triton.cudagraph_trees": False}),
+                                    apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+                                    torch.compile(options={"triton.cudagraphs": True,
+                                                           "triton.cudagraph_trees": False,
+                                                           "triton.cudagraph_support_input_mutation": True}),
+                                    apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
+                                    torch.compile(mode='max-autotune-no-cudagraphs'), apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.unpermute',
+                                    torch.compile(mode='max-autotune-no-cudagraphs'), apply_wrapper=True)

     def patch_core_extentions(self):
-        from ..core.extensions.transformer_engine import te_dot_product_attention_init
+        import transformer_engine as te
+        from ..core.extensions.transformer_engine import te_dot_product_attention_init, TEGroupedLinear

         MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
                                     te_dot_product_attention_init)
+        if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
+            TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)

     def patch_tensor_parallel(self):
         from ..core import vocab_parallel_embedding_forward, vocab_parallel_embedding_init

+        # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
                                     vocab_parallel_embedding_forward)
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
                                     vocab_parallel_embedding_init)
+        # _VocabParallelCrossEntropy
+        MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward',
+                                    torch.compile(mode='max-autotune-no-cudagraphs'), apply_wrapper=True)

     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
         from ..training.initialize import _compile_dependencies
+        from ..training.training import train

         MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
                                     build_tokenizer)
...
@@ -164,6 +176,10 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
                                     _compile_dependencies)
+        # training.train
+        MegatronAdaptation.register('megatron.training.training.train', train)

     def patch_miscellaneous(self):
         from ..training.arguments import parse_args
...
@@ -176,7 +192,22 @@ class LegacyAdaptation(MegatronAdaptationABC):
     """

     def execute(self):
-        pass
+        self.patch_legacy_models()
+
+    def patch_legacy_models(self):
+        from ..legacy.model.transformer import ParallelMLP, ParallelAttention
+
+        # ParallelMLP
+        MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
+                                    ParallelMLP.__init__)
+        MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.forward',
+                                    ParallelAttention.forward)
+
+        # rms_norm.RMSNorm
+        MegatronAdaptation.register('megatron.legacy.model.rms_norm.RMSNorm.forward',
+                                    torch.compile(mode="max-autotune-no-cudagraphs"), apply_wrapper=True)

 MegatronAdaptation.execute()
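The torch.compile registrations above rely on a detail worth calling out: torch.compile invoked with only keyword arguments returns a decorator rather than a compiled function, which is why these entries are registered with the new apply_wrapper=True and get laid over the original Megatron functions instead of replacing them. A minimal, self-contained sketch of that behaviour (illustrative only; rms_norm_ref is a made-up stand-in, not code from this repository):

import torch

# torch.compile called with only keyword arguments returns a decorator,
# so the registered object is a wrapper to stack onto the original
# function, not a replacement for it -- hence apply_wrapper=True.
compile_wrapper = torch.compile(mode='max-autotune-no-cudagraphs')

def rms_norm_ref(x, weight, eps=1e-6):
    # Reference RMSNorm-style normalization, used only for illustration.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

compiled_rms_norm = compile_wrapper(rms_norm_ref)

x, w = torch.randn(4, 8), torch.ones(8)
assert torch.allclose(compiled_rms_norm(x, w), rms_norm_ref(x, w), atol=1e-5)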
dcu_megatron/adaptor/patch_utils.py  (View file @ 4e2de453)

@@ -17,7 +17,7 @@ def dummy_function_wrapper(func_name):

 class Patch:
-    def __init__(self, orig_func_name, new_func, create_dummy):
+    def __init__(self, orig_func_name, new_func, create_dummy, apply_wrapper=False):
         split_name = orig_func_name.rsplit('.', 1)
         if len(split_name) == 1:
             self.orig_module_name, self.orig_func_name = orig_func_name, None
...
@@ -30,7 +30,7 @@ class Patch:
         self.wrappers = []
         if new_func is None:
             new_func = dummy_function_wrapper(orig_func_name)
-        self.set_patch_func(new_func)
+        self.set_patch_func(new_func, apply_wrapper=apply_wrapper)
         self.is_applied = False
         self.create_dummy = create_dummy
...
@@ -42,8 +42,11 @@ class Patch:
     def patch_func_id(self):
         return id(self.patch_func)

-    def set_patch_func(self, new_func, force_patch=False):
-        if hasattr(new_func, '__name__') and new_func.__name__.endswith(('wrapper', 'decorator')):
+    def set_patch_func(self, new_func, force_patch=False, apply_wrapper=False):
+        if (apply_wrapper or
+                (hasattr(new_func, '__name__') and new_func.__name__.endswith(('wrapper', 'decorator')))):
             self.wrappers.append(new_func)
         else:
             if self.patch_func and not force_patch:
...
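For context on the new flag: previously a registered callable was treated as a wrapper only when its __name__ ended in 'wrapper' or 'decorator'; apply_wrapper=True forces wrapper treatment for callables that fail that naming test, such as the object returned by torch.compile(...). A small stand-alone sketch of stacking a list of wrappers over an original function (not the actual Patch class, whose bookkeeping differs):

from functools import reduce

def apply_wrappers(orig_func, wrappers):
    # Stack each wrapper over the original, in registration order;
    # the last registered wrapper ends up outermost.
    return reduce(lambda fn, wrapper: wrapper(fn), wrappers, orig_func)

def log_calls(fn):
    # Its __name__ ends in neither 'wrapper' nor 'decorator', so without an
    # explicit flag a name-based check would not recognize it as a wrapper.
    def inner(*args, **kwargs):
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)
    return inner

def add(a, b):
    return a + b

patched = apply_wrappers(add, [log_calls])
assert patched(2, 3) == 5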
dcu_megatron/legacy/model/rms_norm.py  0 → 100644  (View file @ 4e2de453)

(new file; no content lines are shown for it in this diff)
dcu_megatron/legacy/model/transformer.py  0 → 100644  (View file @ 4e2de453)

import torch
import torch.nn.functional as F

from megatron.training import get_args
from megatron.core import tensor_parallel
from megatron.legacy.model.module import MegatronModule
from megatron.legacy.model.utils import (
    erf_gelu,
    openai_gelu,
)


class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config, is_expert=False):
        super(ParallelMLP, self).__init__()
        args = get_args()

        self.add_bias = config.add_bias_linear

        ffn_hidden_size = config.ffn_hidden_size
        if config.gated_linear_unit:
            ffn_hidden_size *= 2

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            ffn_hidden_size,
            config=config,
            init_method=config.init_method,
            bias=self.add_bias,
            gather_output=False,
            skip_bias_add=True,
            is_expert=is_expert,
        )

        self.bias_gelu_fusion = False
        self.activation_func = None
        self.swiglu = args.swiglu

        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        elif args.swiglu:
            @torch.compile(mode="max-autotune-no-cudagraphs")
            def swiglu(x):
                x = torch.chunk(x, 2, dim=-1)
                return F.silu(x[0]) * x[1]
            self.activation_func = swiglu
        elif args.squared_relu:
            def squared_relu(x):
                return torch.pow(F.relu(x), 2)
            self.activation_func = squared_relu
        else:
            self.bias_gelu_fusion = args.bias_gelu_fusion
            self.activation_func = F.gelu

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=self.add_bias,
            skip_bias_add=True,
            input_is_parallel=True,
            is_expert=is_expert,
        )


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)

                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
                is_first_step = True
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================
        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_query_groups_per_partition,
                ((self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
                 * self.hidden_size_per_attention_head),
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query_layer, key_layer, value_layer) = torch.split(
                mixed_x_layer,
                [
                    (self.num_attention_heads_per_partition // self.num_query_groups_per_partition
                     * self.hidden_size_per_attention_head),
                    self.hidden_size_per_attention_head,
                    self.hidden_size_per_attention_head
                ],
                dim=3)

            # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
            query_layer = query_layer.contiguous().view(
                query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None:
            if isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = rotary_pos_emb
            else:
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence)
                    q_pos_emb = q_pos_emb[sequence_end - 1:sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
        if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
            key_layer = key_layer.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim=2)
            value_layer = value_layer.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim=2)

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.config)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.config)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)
        else:
            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            if not self.sequence_parallel:
                with tensor_parallel.get_cuda_rng_tracker().fork():
                    context_layer = self.core_attention_flash(q, k, v)
            else:
                context_layer = self.core_attention_flash(q, k, v)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

        # =================
        # Output. [sq, b, h]
        # =================
        output, bias = self.dense(context_layer)

        return output, bias
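Since ParallelMLP.__init__ now builds its SwiGLU activation as a torch.compile-decorated closure, a quick way to sanity-check that pattern in isolation is to compare the compiled closure against an eager reference. The snippet below is a stand-alone sketch, not part of the patch:

import torch
import torch.nn.functional as F

@torch.compile(mode="max-autotune-no-cudagraphs")
def swiglu(x):
    # Same shape as in the patch: last dim holds the gate and value halves.
    x = torch.chunk(x, 2, dim=-1)
    return F.silu(x[0]) * x[1]

def swiglu_eager(x):
    a, b = torch.chunk(x, 2, dim=-1)
    return F.silu(a) * b

x = torch.randn(2, 4, 16)   # [..., 2 * ffn_hidden], as produced by dense_h_to_4h
assert torch.allclose(swiglu(x), swiglu_eager(x), atol=1e-5)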
dcu_megatron/training/arguments.py  (View file @ 4e2de453)

@@ -4,7 +4,6 @@ import argparse

 from megatron.training.arguments import (
     _add_network_size_args,
     _add_regularization_args,
-    _add_training_args,
     _add_initialization_args,
     _add_learning_rate_args,
     _add_checkpointing_args,
...
@@ -249,6 +248,8 @@ def _add_tokenizer_args(parser):
                                'GPTSentencePieceTokenizer',
                                'HuggingFaceTokenizer',
                                'Llama2Tokenizer',
+                               'Llama3Tokenizer',
+                               'QwenTokenizer',
                                'TikTokenizer',
                                'MultimodalTokenizer',
                                'NullTokenizer',
...
@@ -265,6 +266,255 @@ def _add_tokenizer_args(parser):
     return parser


+def _add_training_args(parser):
+    group = parser.add_argument_group(title='training')
+
+    group.add_argument('--micro-batch-size', type=int, default=None,
+                       help='Batch size per model instance (local batch size). Global batch size is '
+                            'local batch size times data parallel size times number of micro batches.')
+    group.add_argument('--batch-size', type=int, default=None,
+                       help='Old batch size parameter, do not use. Use --micro-batch-size instead')
+    group.add_argument('--global-batch-size', type=int, default=None,
+                       help='Training batch size. If set, it should be a multiple of micro-batch-size '
+                            'times data-parallel-size. If this value is None, then use '
+                            'micro-batch-size * data-parallel-size as the global batch size. '
+                            'This choice will result in 1 for number of micro-batches.')
+    group.add_argument('--rampup-batch-size', nargs='*', default=None,
+                       help='Batch size ramp up with the following values: '
+                            '--rampup-batch-size <start batch size> <batch size incerement> <ramp-up samples> '
+                            'For example: --rampup-batch-size 16 8 300000 \\ --global-batch-size 1024 '
+                            'will start with global batch size 16 and over (1024 - 16) / 8 = 126 intervals '
+                            'will increase the batch size linearly to 1024. In each interval we will use '
+                            'approximately 300000 / 126 = 2380 samples.')
+    group.add_argument('--decrease-batch-size-if-needed', action='store_true', default=False,
+                       help='If set, decrease batch size if microbatch_size * dp_size does not divide '
+                            'batch_size. Useful for KSO (Keep Soldiering On) to continue making progress '
+                            'if number of healthy GPUs (and corresponding dp_size) does not support '
+                            'current batch_size. Old batch_size will be restored if training is '
+                            're-started with dp_size that divides batch_size // microbatch_size.')
+    group.add_argument('--recompute-activations', action='store_true',
+                       help='recompute activation to allow for training with larger models, sequences, '
+                            'and batch sizes.')
+    group.add_argument('--recompute-granularity', type=str, default=None, choices=['full', 'selective'],
+                       help='Checkpoint activations to allow for training with larger models, sequences, '
+                            'and batch sizes. It is supported at two granularities 1) full: whole '
+                            'transformer layer is recomputed, 2) selective: core attention part of the '
+                            'transformer layer is recomputed.')
+    group.add_argument('--no-check-for-nan-in-loss-and-grad', action='store_false',
+                       help='Check for NaNs in loss and grad', dest='check_for_nan_in_loss_and_grad')
+    group.add_argument('--check-for-spiky-loss', action='store_true',
+                       help='Check for spiky loss', dest='check_for_spiky_loss')
+    group.add_argument('--distribute-saved-activations', action='store_true',
+                       help='If set, distribute recomputed activations across model parallel group.')
+    group.add_argument('--recompute-method', type=str, default=None, choices=['uniform', 'block'],
+                       help='1) uniform: uniformly divide the total number of Transformer layers and '
+                            'recompute the input activation of each divided chunk at specified granularity, '
+                            '2) recompute the input activations of only a set number of individual '
+                            'Transformer layers per pipeline stage and do the rest without any recomputing '
+                            'at specified granularity default) do not apply activations recompute to any layers')
+    group.add_argument('--recompute-num-layers', type=int, default=None,
+                       help='1) uniform: the number of Transformer layers in each uniformly divided '
+                            'recompute unit, 2) block: the number of individual Transformer layers to '
+                            'recompute within each pipeline stage.')
+    group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false',
+                       help='If not set, clone the output of the scatter in embedding layer to GC original tensor.',
+                       dest='clone_scatter_output_in_embedding')
+    group.add_argument('--profile', action='store_true',
+                       help='Enable nsys profiling. When using this option, nsys options should be '
+                            'specified in commandline. An example nsys commandline is '
+                            '`nsys profile -s none -t nvtx,cuda -o <path/to/output_file> '
+                            '--force-overwrite true --capture-range=cudaProfilerApi '
+                            '--capture-range-end=stop`.')
+    group.add_argument('--profile-step-start', type=int, default=10,
+                       help='Global step to start profiling.')
+    group.add_argument('--profile-step-end', type=int, default=12,
+                       help='Global step to stop profiling.')
+    group.add_argument('--use-pytorch-profiler', action='store_true',
+                       help='Use the built-in pytorch profiler. Useful if you wish to view profiles in tensorboard.',
+                       dest='use_pytorch_profiler')
+    group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
+                       help='Global ranks to profile.')
+    group.add_argument('--record-memory-history', action="store_true", default=False,
+                       help='Record memory history in last rank.')
+    group.add_argument('--memory-snapshot-path', type=str, default="snapshot.pickle",
+                       help='Specifies where to dump the memory history pickle.')
+    group.add_argument('--tp-comm-overlap', action='store_true',
+                       help='Enables the overlap of Tensor parallel communication and GEMM kernels.')
+    group.add_argument('--tp-comm-overlap-cfg', type=str, default=None,
+                       help='Config file when tp_comm_overlap is enabled.')
+    group.add_argument('--disable-tp-comm-overlap-ag', action='store_false',
+                       help=('Disables the All-Gather overlap with GEMM by pipelining the GEMM and All-Gather.'),
+                       dest='tp_comm_overlap_ag')
+    group.add_argument('--disable-tp-comm-overlap-rs', action='store_false',
+                       help=('Disables the Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter.'),
+                       dest='tp_comm_overlap_rs')
+    group.add_argument('--tp-comm-overlap-rs-dgrad', action='store_true',
+                       help='Enables the Reduce-Scatter overlap with dgrad GEMM.',
+                       dest='tp_comm_overlap_rs_dgrad')
+    group.add_argument('--disable-tp-comm-bulk-dgrad', action='store_false',
+                       help='Disables the All-Gather overlap with bprop activation gradient GEMM.',
+                       dest='tp_comm_bulk_dgrad')
+    group.add_argument('--disable-tp-comm-bulk-wgrad', action='store_false',
+                       help='Disables the Reduce-Scatter overlap with bprop weight gradient GEMM.',
+                       dest='tp_comm_bulk_wgrad')
+    group.add_argument('--tp-comm-bootstrap-backend', default='nccl', type=str, choices=['nccl', 'mpi', 'gloo'],
+                       help='Set the bootstrapping backend of Tensor parallel communications.')
+    group.add_argument('--use-cpu-initialization', action='store_true', default=None,
+                       help='If set, initialize weights on the CPU. This eliminates init differences based '
+                            'on tensor parallelism.')
+    group.add_argument('--empty-unused-memory-level', default=0, type=int, choices=[0, 1, 2],
+                       help='Call torch.cuda.empty_cache() each iteration (training and eval), to reduce '
+                            'fragmentation. 0=off, 1=moderate, 2=aggressive.')
+    group.add_argument('--deterministic-mode', action='store_true',
+                       help='Choose code that has deterministic execution. This usually means slower '
+                            'execution, but is good for debugging and testing.')
+    group.add_argument('--check-weight-hash-across-dp-replicas-interval', type=int, default=None,
+                       help='Interval to check weight hashes are same across DP replicas. If not specified, '
+                            'weight hashes not checked.')
+    group.add_argument('--calculate-per-token-loss', action='store_true',
+                       help=('Scale cross entropy loss by the number of non-padded tokens in the global '
+                             'batch, versus the default behavior of assuming all tokens are non-padded.'))
+    group.add_argument('--train-sync-interval', type=int, default=None,
+                       help='Training CPU-GPU synchronization interval, to ensure that CPU is not running '
+                            'too far ahead of GPU.')
+
+    # deprecated
+    group.add_argument('--checkpoint-activations', action='store_true',
+                       help='Checkpoint activation to allow for training with larger models, sequences, '
+                            'and batch sizes.')
+    group.add_argument('--train-iters', type=int, default=None,
+                       help='Total number of iterations to train over all training runs. Note that either '
+                            'train-iters or train-samples should be provided.')
+    group.add_argument('--train-samples', type=int, default=None,
+                       help='Total number of samples to train over all training runs. Note that either '
+                            'train-iters or train-samples should be provided.')
+    group.add_argument('--log-interval', type=int, default=100,
+                       help='Report loss and timing interval.')
+    group.add_argument('--exit-interval', type=int, default=None,
+                       help='Exit the program after the iteration is divisible by this value.')
+    group.add_argument('--exit-duration-in-mins', type=int, default=None,
+                       help='Exit the program after this many minutes.')
+    group.add_argument('--exit-signal-handler', action='store_true',
+                       help='Dynamically save the checkpoint and shutdown the training if SIGTERM is received')
+    group.add_argument('--tensorboard-dir', type=str, default=None,
+                       help='Write TensorBoard logs to this directory.')
+    group.add_argument('--no-masked-softmax-fusion', action='store_false',
+                       help='Disable fusion of query_key_value scaling, masking, and softmax.',
+                       dest='masked_softmax_fusion')
+    group.add_argument('--no-bias-gelu-fusion', action='store_false',
+                       help='Disable bias and gelu fusion.', dest='bias_gelu_fusion')
+    group.add_argument('--no-bias-swiglu-fusion', action='store_false',
+                       help='Disable bias and swiglu fusion, the fusion is available only when using megatron-core.',
+                       dest='bias_swiglu_fusion')
+    group.add_argument('--no-bias-dropout-fusion', action='store_false',
+                       help='Disable bias and dropout fusion.', dest='bias_dropout_fusion')
+    group.add_argument('--no-rope-fusion', action='store_false',
+                       help='Disable rope fusion, the fusion is available only when using megatron-core.',
+                       dest='apply_rope_fusion')
+    group.add_argument('--cross-entropy-loss-fusion', action='store_true',
+                       help='Enabled fusion of cross entropy loss calculation.',
+                       dest='cross_entropy_loss_fusion')
+    group.add_argument('--use-flash-attn', action='store_true',
+                       help='use FlashAttention implementation of attention. https://arxiv.org/abs/2205.14135')
+    group.add_argument('--disable-bias-linear', action='store_false',
+                       help='Disable bias in the linear layers', dest='add_bias_linear')
+    group.add_argument('--add-qkv-bias', action='store_true',
+                       help='Enable bias only in the QKV linear layers', dest='add_qkv_bias')
+    group.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'],
+                       help='Optimizer function')
+    group.add_argument('--dataloader-type', type=str, default=None, choices=['single', 'cyclic', 'external'],
+                       help='Single pass vs multiple pass data loader')
+    group.add_argument('--no-async-tensor-model-parallel-allreduce', action='store_false',
+                       help='DEPRECATED. This flag is ignored.',
+                       dest='async_tensor_model_parallel_allreduce')
+    group.add_argument('--no-persist-layer-norm', action='store_true',
+                       help='Disable using persistent fused layer norm kernel. This kernel supports only a '
+                            'set of hidden sizes. Please check persist_ln_hidden_sizes if your hidden '
+                            'size is supported.')
+    group.add_argument('--sequence-parallel', action='store_true',
+                       help='Enable sequence parallel optimization.')
+    group.add_argument('--no-gradient-accumulation-fusion', action='store_false',
+                       help='Disable fusing gradient accumulation to weight gradient computation of linear layers',
+                       dest='gradient_accumulation_fusion')
+    group.add_argument('--use-mcore-models', action='store_true', dest='deprecated_use_mcore_models',
+                       help='DEPRECATED. Use the implementation from megatron core. Now ignored and mcore '
+                            'models are the default, use --use-legacy-models to not use core models.')
+    group.add_argument('--use-legacy-models', action='store_true',
+                       help='Use the legacy Megatron models, not Megatron-Core models.')
+    group.add_argument('--manual-gc', action='store_true',
+                       help='Disable the threshold-based default garbage collector and trigger the garbage '
+                            'collection manually. Manual garbage collection helps to align the timing of '
+                            'the collection across ranks which mitigates the impact of CPU-associated '
+                            'jitters. When the manual gc is enabled, garbage collection is performed only '
+                            'at the start and the end of the validation routine by default.')
+    group.add_argument('--manual-gc-interval', type=int, default=0,
+                       help='Training step interval to trigger manual garbage collection. When the value '
+                            'is set to 0, garbage collection is not triggered between training steps.')
+    group.add_argument('--no-manual-gc-eval', action='store_false',
+                       help='When using manual garbage collection, disable garbage collection at the start '
+                            'and the end of each evaluation run.', dest='manual_gc_eval')
+    group.add_argument('--disable-tp-comm-split-ag', action='store_false',
+                       help='Disables the All-Gather overlap with fprop GEMM.', dest='tp_comm_split_ag')
+    group.add_argument('--disable-tp-comm-split-rs', action='store_false',
+                       help='Disables the Reduce-Scatter overlap with fprop GEMM.', dest='tp_comm_split_rs')
+    group.add_argument('--profile-dir', type=str, default="./",
+                       help='profile dir to save.')
+
+    return parser
+
+
 def _add_mtp_args(parser):
     group = parser.add_argument_group(title='multi token prediction')
     group.add_argument('--num-nextn-predict-layers', type=int, default=0,
                        help='Multi-Token prediction layer num')
...
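Many of the flags above follow the argparse pattern of a '--no-…' switch with action='store_false' and an explicit dest, so the destination defaults to True and passing the flag turns it off. A minimal illustration (hypothetical parser, not the repository's):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
# With action='store_false', the destination defaults to True and the flag flips it to False.
group.add_argument('--no-bias-gelu-fusion', action='store_false',
                   help='Disable bias and gelu fusion.', dest='bias_gelu_fusion')
group.add_argument('--micro-batch-size', type=int, default=None)

args = parser.parse_args(['--no-bias-gelu-fusion', '--micro-batch-size', '4'])
assert args.bias_gelu_fusion is False and args.micro_batch_size == 4
assert parser.parse_args([]).bias_gelu_fusion is True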
dcu_megatron/training/tokenizer/tokenizer.py  (View file @ 4e2de453)

-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, Qwen2Tokenizer

 from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
 from megatron.training.tokenizer.tokenizer import (
     _BertWordPieceTokenizer,
...
@@ -46,6 +46,11 @@ def build_tokenizer(args, **kwargs):
     elif args.tokenizer_type == 'Llama2Tokenizer':
         assert args.tokenizer_model is not None
         tokenizer = _Llama2Tokenizer(args.tokenizer_model)
+    elif args.tokenizer_type == 'Llama3Tokenizer':
+        assert args.tokenizer_model is not None
+        tokenizer = _Llama3Tokenizer(args.tokenizer_model)
+    elif args.tokenizer_type == 'QwenTokenizer':
+        tokenizer = _Qwen2Tokenizer(args.vocab_file, args.merge_file)
     elif args.tokenizer_type == 'TikTokenizer':
         assert args.tokenizer_model is not None
         assert args.tiktoken_pattern is not None
...
@@ -101,6 +106,96 @@ def build_tokenizer(args, **kwargs):
     return tokenizer


+class _Llama3Tokenizer(MegatronTokenizer):
+    """tiktoken tokenizer rewritten for Megatron Llama 3."""
+    # https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py
+
+    def __init__(self, model_file):
+        super().__init__(model_file)
+        from pathlib import Path
+        import tiktoken
+        from tiktoken.load import load_tiktoken_bpe
+
+        tokenizer_path = model_file
+        special_tokens = [
+            "<|begin_of_text|>",
+            "<|end_of_text|>",
+            "<|reserved_special_token_0|>",
+            "<|reserved_special_token_1|>",
+            "<|reserved_special_token_2|>",
+            "<|reserved_special_token_3|>",
+            "<|start_header_id|>",
+            "<|end_header_id|>",
+            "<|reserved_special_token_4|>",
+            "<|eot_id|>",  # end of turn
+        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
+        mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
+        self.tokenizer = tiktoken.Encoding(
+            tokenizer_path,
+            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
+            mergeable_ranks=mergeable_ranks,
+            special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
+        )
+        self.eod_id = self.tokenizer.encode("<|end_of_text|>", allowed_special="all")[0]
+
+    @property
+    def vocab_size(self):
+        return self.tokenizer.n_vocab
+
+    @property
+    def vocab(self):
+        return self.tokenizer.encode
+
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.encode
+
+    def tokenize(self, text):
+        return self.tokenizer.encode(text)
+
+    def detokenize(self, token_ids):
+        return self.tokenizer.encode(token_ids)  # NOTE: presumably intended to call self.tokenizer.decode here
+
+    @property
+    def eod(self):
+        return self.eod_id
+
+
+class _Qwen2Tokenizer(MegatronTokenizer):
+    def __init__(self, vocab_file, merge_file, extra_vocab_size=0):
+        super().__init__(vocab_file, merge_file)
+        self.tokenizer = Qwen2Tokenizer(vocab_file, merge_file)
+        self.extra_vocab_size = extra_vocab_size
+        self.tokenizer.add_special_tokens(special_tokens_dict=dict(pad_token="<|extra_0|>"))
+
+    @property
+    def vocab_size(self):
+        return len(self.tokenizer.encoder) + self.extra_vocab_size
+
+    @property
+    def vocab(self):
+        return self.tokenizer.encoder
+
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.decoder
+
+    def tokenize(self, text):
+        return self.tokenizer.encode(text)
+
+    def detokenize(self, token_ids):
+        return self.tokenizer.decode(token_ids)
+
+    @property
+    def eod(self):
+        return self.tokenizer.eos_token_id
+
+    @property
+    def eos_token(self):
+        return self.tokenizer.eos_token
+
+    @property
+    def pad_token_id(self):
+        return self.tokenizer.pad_token_id


 class _DeepSeekV2Tokenizer(MegatronTokenizer):
     def __init__(self, tokenizer_path, extra_vocab_size):
         super().__init__(tokenizer_path)
...
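Both new classes wrap a third-party tokenizer behind the interface Megatron's datasets expect (vocab_size, vocab, inv_vocab, tokenize, detokenize, eod). The toy stand-in below sketches that surface without needing any model or vocab files; it is illustrative only and not equivalent to the tiktoken- or HuggingFace-backed implementations above:

class WhitespaceTokenizer:
    """Toy stand-in exposing the same surface the new wrappers implement."""

    def __init__(self, corpus):
        words = sorted(set(corpus.split()))
        self._vocab = {w: i for i, w in enumerate(words)}
        self._inv = {i: w for w, i in self._vocab.items()}
        self.eod = len(self._vocab)          # one id past the word vocabulary

    @property
    def vocab_size(self):
        return len(self._vocab) + 1          # words + eod

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv

    def tokenize(self, text):
        return [self._vocab[w] for w in text.split()]

    def detokenize(self, token_ids):
        return " ".join(self._inv[i] for i in token_ids)

tok = WhitespaceTokenizer("the quick brown fox")
ids = tok.tokenize("quick fox")
assert tok.detokenize(ids) == "quick fox"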
dcu_megatron/training/training.py  0 → 100644  (View file @ 4e2de453)

import gc
import sys

import torch.distributed
import torch

from megatron.core import mpu
from megatron.core.utils import (
    check_param_hashes_across_dp_replicas,
    StragglerDetector,
)
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed import finalize_model_grads
from megatron.training.initialize import write_args_to_tensorboard
from megatron.core.num_microbatches_calculator import (
    get_current_global_batch_size,
    get_current_running_global_batch_size,
    get_num_microbatches,
    update_num_microbatches)
from megatron.training.async_utils import maybe_finalize_async_save
from megatron.training.utils import (
    calc_params_l2_norm,
    print_rank_0,
)
from megatron.training.global_vars import (
    get_args,
    get_timers,
    get_tensorboard_writer,
    get_wandb_writer,
    get_one_logger,
)
from megatron.training import one_logger_utils
from megatron.training import ft_integration
from megatron.training.training import (
    print_datetime,
    disable_forward_pre_hook,
    train_step,
    save_checkpoint_and_time,
    enable_forward_pre_hook,
    num_floating_point_operations,
    training_log,
    evaluate_and_print_results,
    post_training_step_callbacks,
    checkpoint_and_decide_exit,
)


stimer = StragglerDetector()


def train(forward_step_func, model, optimizer, opt_param_scheduler,
          train_data_iterator, valid_data_iterator,
          process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
    """Training function: run train_step desired number of times, run validation, checkpoint."""
    args = get_args()
    timers = get_timers()
    one_logger = get_one_logger()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    # Track E2E metrics at the start of training.
    one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
                                    train_samples=args.train_samples, seq_length=args.seq_length,
                                    train_iters=args.train_iters, save=args.save, async_save=args.async_save,
                                    log_throughput=args.log_throughput,
                                    num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)

    num_floating_point_operations_so_far = args.num_floating_point_operations_so_far

    # Setup some training config params.
    config.grad_scale_func = optimizer.scale_loss
    config.timers = timers
    if isinstance(model[0], DDP) and args.overlap_grad_reduce:
        assert config.no_sync_func is None, \
            ('When overlap_grad_reduce is True, config.no_sync_func must be None; '
             'a custom no_sync_func is not supported when overlapping grad-reduce')
        config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
        if len(model) == 1:
            config.no_sync_func = config.no_sync_func[0]
        if args.align_grad_reduce:
            config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model]
            if len(model) == 1:
                config.grad_sync_func = config.grad_sync_func[0]
    if args.overlap_param_gather and args.align_param_gather:
        config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
        if len(model) == 1:
            config.param_sync_func = config.param_sync_func[0]
    config.finalize_model_grads_func = finalize_model_grads

    timers('interval-time', log_level=0).start(barrier=True)
    print_datetime('before the start of training step')
    report_memory_flag = True
    pre_hook_enabled = False
    should_exit = False
    exit_code = 0

    if args.manual_gc:
        # Disable the default garbage collector and perform the collection manually.
        # This is to align the timing of garbage collection across ranks.
        assert args.manual_gc_interval >= 0, \
            'Manual garbage collection interval should be larger than or equal to 0'
        gc.disable()
        gc.collect()

    # Singleton initialization of straggler detector.
    if args.log_straggler:
        global stimer
        world = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        mmcnt = args.straggler_minmax_count
        stimer.configure(world, rank,
                         mmcnt=mmcnt,
                         enabled=not args.disable_straggler_on_startup,
                         port=args.straggler_ctrlr_port)
    num_floating_point_operations_since_last_log_event = 0.0

    num_microbatches = get_num_microbatches()
    eval_duration = 0.0
    eval_iterations = 0

    def get_e2e_base_metrics():
        """Get base metrics values for one-logger to calculate E2E tracking metrics."""
        num_floating_point_operations_since_current_train_start = \
            num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
        return {
            'iteration': iteration,
            'train_duration': timers('interval-time').active_time(),
            'eval_duration': eval_duration,
            'eval_iterations': eval_iterations,
            'total_flops_since_current_train_start': num_floating_point_operations_since_current_train_start,
            'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
            'consumed_train_samples': args.consumed_train_samples,
            'world_size': args.world_size,
            'seq_length': args.seq_length
        }

    # Cache into one-logger for callback.
    if one_logger:
        with one_logger.get_context_manager():
            one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)

    prof = None
    if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
        def trace_handler(p):
            from pathlib import Path
            Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
            if args.rank in [0]:
                print(p.key_averages(group_by_input_shape=True, group_by_stack_n=5).table(
                    sort_by="self_cuda_time_total", row_limit=-1, max_src_column_width=100,
                    max_name_column_width=280, max_shapes_column_width=200))
            p.export_chrome_trace("{path}/trace_rank{rank}_step{step}.json".format(
                path=args.profile_dir, rank=torch.distributed.get_rank(), step=p.step_num))

        prof = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(
                wait=max(args.profile_step_start - 1, 0),
                warmup=1 if args.profile_step_start > 0 else 0,
                active=args.profile_step_end - args.profile_step_start,
                repeat=1),
            record_shapes=True,
            # on_trace_ready=torch.profiler.tensorboard_trace_handler('./torch_prof_data'))
            on_trace_ready=trace_handler)
        prof.start()

    start_iteration = iteration
    # Disable forward pre-hook to start training to ensure that errors in checkpoint loading
    # or random initialization don't propagate to all ranks in first all-gather (which is a
    # no-op if things work correctly).
    if args.use_distributed_optimizer and args.overlap_param_gather:
        disable_forward_pre_hook(model, param_sync=False)
        # Also remove param_sync_func temporarily so that sync calls made in
        # `forward_backward_func` are no-ops.
        param_sync_func = config.param_sync_func
        config.param_sync_func = None
        pre_hook_enabled = False
    # Also, check weight hash across DP replicas to be very pedantic.
    if args.check_weight_hash_across_dp_replicas_interval is not None:
        assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
            "Parameter hashes not matching across DP replicas"
        torch.distributed.barrier()
        print_rank_0(f">>> Weight hashes match after {iteration} iterations...")

    # Run training iterations till done.
    while iteration < args.train_iters:
        if args.profile and torch.distributed.get_rank() in args.profile_ranks:
            if args.use_pytorch_profiler:
                prof.step()
            elif iteration == args.profile_step_start:
                torch.cuda.cudart().cudaProfilerStart()
                torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()

        ft_integration.on_checkpointing_start()
        maybe_finalize_async_save(blocking=False)
        ft_integration.on_checkpointing_end(is_async_finalization=True)

        # Update number of microbatches first without consistency check to decide if a
        # checkpoint should be saved. If the number of microbatches is different
        # from the previous iteration, save a checkpoint. Then run consistency check
        # to make sure training configuration is still valid.
        update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
        if get_num_microbatches() != num_microbatches and iteration != 0:
            assert get_num_microbatches() > num_microbatches, \
                (f"Number of microbatches should be increasing due to batch size rampup; "
                 f"instead going from {num_microbatches} to {get_num_microbatches()}")
            if args.save is not None:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         opt_param_scheduler,
                                         num_floating_point_operations_so_far,
                                         checkpointing_context, train_data_iterator=train_data_iterator)
        num_microbatches = get_num_microbatches()
        update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)

        # Run training step.
        args.curr_iteration = iteration
        ft_integration.on_training_step_start()
        loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       opt_param_scheduler,
                       config)
        ft_integration.on_training_step_end()
        if should_checkpoint:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     opt_param_scheduler,
                                     num_floating_point_operations_so_far,
                                     checkpointing_context, train_data_iterator=train_data_iterator)
        if should_exit:
            break

        # Enable forward pre-hooks after first set of forward and backward passes.
        # When running in fp16, skip all NaN iterations until steady-state loss scaling value
        # is reached.
        if iteration == start_iteration:
            if skipped_iter:
                # Only enable forward pre-hook after a training step has successfully run. Relevant
                # for fp16 codepath where first XX iterations are skipped until steady-state loss
                # scale value is reached.
                start_iteration = iteration + 1
            else:
                # Enable forward pre-hook after training step has successfully run. All subsequent
                # forward passes will use the forward pre-hook / `param_sync_func` in
                # `forward_backward_func`.
                if args.use_distributed_optimizer and args.overlap_param_gather:
                    enable_forward_pre_hook(model)
                    config.param_sync_func = param_sync_func
                    pre_hook_enabled = True

        iteration += 1
        batch_size = mpu.get_data_parallel_world_size() * \
                     args.micro_batch_size * \
                     get_num_microbatches()
        args.consumed_train_samples += batch_size
        num_skipped_samples_in_batch = (get_current_global_batch_size() -
                                        get_current_running_global_batch_size())
        if args.decrease_batch_size_if_needed:
            assert num_skipped_samples_in_batch >= 0
        else:
            assert num_skipped_samples_in_batch == 0
        args.skipped_train_samples += num_skipped_samples_in_batch
        num_floating_point_operations_in_batch = num_floating_point_operations(args, batch_size)
        num_floating_point_operations_so_far += num_floating_point_operations_in_batch
        num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch

        # Logging.
        if not optimizer.is_stub_optimizer:
            loss_scale = optimizer.get_loss_scale().item()
        else:
            loss_scale = 1.0
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        learning_rate = None
        decoupled_learning_rate = None
        for param_group in optimizer.param_groups:
            if param_group['is_decoupled_lr']:
                decoupled_learning_rate = param_group['lr']
            else:
                learning_rate = param_group['lr']
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          learning_rate,
                                          decoupled_learning_rate,
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)

        # Evaluation.
        if args.eval_interval and iteration % args.eval_interval == 0 and \
                args.do_valid:
            timers('interval-time').stop()
            if args.use_distributed_optimizer and args.overlap_param_gather:
                disable_forward_pre_hook(model)
                pre_hook_enabled = False
            if args.manual_gc and args.manual_gc_eval:
                # Collect all objects.
                gc.collect()
            prefix = f'iteration {iteration}'
            timers('eval-time', log_level=0).start(barrier=True)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, process_non_loss_data_func,
                                       config, verbose=False, write_to_tensorboard=True,
                                       non_loss_data_func=non_loss_data_func)
            eval_duration += timers('eval-time').elapsed()
            eval_iterations += args.eval_iters
            timers('eval-time').stop()
            one_logger_utils.track_e2e_metrics()

            if args.manual_gc and args.manual_gc_eval:
                # Collect only the objects created and used in evaluation.
                gc.collect(generation=0)
            if args.use_distributed_optimizer and args.overlap_param_gather:
                enable_forward_pre_hook(model)
                pre_hook_enabled = True
            timers('interval-time', log_level=0).start(barrier=True)

        # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
        # Some of these only happen at specific iterations.
        post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
                                     num_floating_point_operations_since_last_log_event)

        # Checkpoint and decide whether to exit.
        should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
                                                 num_floating_point_operations_so_far,
                                                 checkpointing_context, train_data_iterator)
        if should_exit:
            break

    one_logger_utils.track_e2e_metrics()

    # Flush TensorBoard, WandB writers and one-logger.
    writer = get_tensorboard_writer()
    if writer:
        writer.flush()

    # Close out pre-hooks if using distributed optimizer and overlapped param gather.
    if pre_hook_enabled:
        disable_forward_pre_hook(model)

    ft_integration.on_checkpointing_start()
    maybe_finalize_async_save(blocking=True)
    ft_integration.on_checkpointing_end(is_async_finalization=True)

    # If any exit conditions (signal handler, duration, iterations) have been reached, exit.
    if should_exit:
        wandb_writer = get_wandb_writer()
        if wandb_writer:
            wandb_writer.finish()
        ft_integration.shutdown()
        sys.exit(exit_code)

    return iteration, num_floating_point_operations_so_far
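The main functional addition over upstream Megatron's train() here is the torch.profiler path: a schedule derived from --profile-step-start/--profile-step-end plus a trace handler that prints a kernel summary and exports a Chrome trace into --profile-dir. The stand-alone sketch below exercises the same profiler API on a dummy CPU-only loop (the ./profile_out directory is hypothetical):

import os
import torch

def trace_handler(p):
    # Print a summary and dump a Chrome trace, mirroring the handler wired in above.
    os.makedirs("./profile_out", exist_ok=True)      # hypothetical output directory
    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
    p.export_chrome_trace(f"./profile_out/trace_step{p.step_num}.json")

prof = torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
    record_shapes=True,
    on_trace_ready=trace_handler)

prof.start()
for step in range(6):                 # stand-in for the training loop
    x = torch.randn(256, 256)
    (x @ x).sum().item()
    prof.step()                       # advances the wait/warmup/active schedule
prof.stop()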