evt_fugx1 / dcu_megatron / Commits / c63fceee

Commit c63fceee, authored May 19, 2025 by dongcl

Merge branch 'a2a_overlap' into 'core_v0.12.0'

A2a overlap. See merge request OpenDAS/dcu_megatron!4

Parents: 6c3cfb1d, bfe0b4a9

Showing 12 changed files with 504 additions and 191 deletions (+504 / -191):
- dcu_megatron/adaptor/megatron_adaptor.py (+14 / -0)
- dcu_megatron/core/models/gpt/fine_grained_schedule.py (+9 / -0)
- dcu_megatron/core/pipeline_parallel/schedules.py (+1 / -13)
- dcu_megatron/core/tensor_parallel/__init__.py (+1 / -0)
- dcu_megatron/core/tensor_parallel/mappings.py (+72 / -0)
- dcu_megatron/core/tensor_parallel/qcomm.py (+217 / -0)
- dcu_megatron/core/transformer/moe/token_dispatcher.py (+12 / -3)
- dcu_megatron/core/transformer/transformer_layer.py (+53 / -34)
- dcu_megatron/training/arguments.py (+13 / -0)
- dcu_megatron/training/initialize.py (+37 / -2)
- dcu_megatron/training/tokenizer/tokenizer.py (+6 / -0)
- dcu_megatron/training/training.py (+69 / -139)
dcu_megatron/adaptor/megatron_adaptor.py

```diff
@@ -169,6 +169,15 @@ class CoreAdaptation(MegatronAdaptationABC):
                                      staticmethod, apply_wrapper=True)

+        # reduce_scatter_to_sequence_parallel_region
+        MegatronAdaptation.register(
+            'megatron.core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region',
+            torch._dynamo.disable, apply_wrapper=True)
+
+        # reduce_from_tensor_model_parallel_region
+        MegatronAdaptation.register(
+            'megatron.core.tensor_parallel.mappings.reduce_from_tensor_model_parallel_region',
+            torch._dynamo.disable, apply_wrapper=True)

         # flux
         if int(os.getenv("USE_FLUX_OVERLAP", "0")):
             from ..core.tensor_parallel.layers import (
@@ -189,6 +198,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..training.initialize import _initialize_distributed
         from ..training.initialize import _compile_dependencies
         from ..training.training import train
+        from ..training.initialize import _set_random_seed

         MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
                                     build_tokenizer)
@@ -199,6 +209,10 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
                                     _compile_dependencies)

+        # register a fixed random seed handler
+        MegatronAdaptation.register('megatron.training.initialize._set_random_seed',
+                                    _set_random_seed)

         # add trace_handler
         MegatronAdaptation.register('megatron.training.training.train', train)
```
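The two new registrations above wrap the tensor-parallel mapping functions in `torch._dynamo.disable`. As a rough illustration of what that buys (a sketch, not part of the MR; it assumes a PyTorch 2.x build where `torch.compile` works, and `reduce_like_op` is a made-up stand-in for the real collective): the wrapped callable is excluded from compilation and always runs eagerly, producing a graph break instead of a tracing failure.

```python
import torch

def reduce_like_op(x):
    # stand-in for a collective that should not be traced by torch.compile
    return x.sum(dim=0)

# what registering torch._dynamo.disable over the mapping function amounts to
reduce_like_op_eager = torch._dynamo.disable(reduce_like_op)

@torch.compile
def layer(x):
    # the disabled call graph-breaks here and runs eagerly
    return reduce_like_op_eager(x) + 1.0

print(layer(torch.ones(4, 8)))
```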
dcu_megatron/core/models/gpt/fine_grained_schedule.py

```diff
@@ -397,6 +397,10 @@ class DenseAttnNode(TransformerLayerNode):
         )
         return hidden_states

+    def dw(self):
+        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
+            self.layer._submodule_attention_dw()


 class FakeScheduleNode:
@@ -411,6 +415,10 @@ class DenseMlpNode(TransformerLayerNode):
     def forward_impl(self, hidden_states):
         return self.layer._submodule_dense_forward(hidden_states)

+    def dw(self):
+        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
+            self.layer._submodule_mlp_dw()


 def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream):
     common_state = TransformerLayerState()
@@ -418,6 +426,7 @@ def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream):
     attn.name = "attn"
     dispatch = FakeScheduleNode()
     mlp = DenseMlpNode(chunk_state, common_state, layer, comp_stream, event)
     mlp.name = "mlp"
     combine = FakeScheduleNode()

     return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
```
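Both new `dw()` hooks wrap the deferred weight-gradient call in a named NVTX range so it shows up as a labelled span in Nsight Systems timelines. A minimal sketch of that pattern, assuming a CUDA build of PyTorch and using made-up `FakeLayer` / `FakeMlpNode` stand-ins:

```python
import torch

class FakeLayer:
    def _submodule_mlp_dw(self):
        # stand-in for the deferred weight-gradient computation
        pass

class FakeMlpNode:
    def __init__(self, layer, name):
        self.layer = layer
        self.name = name

    def dw(self):
        # same shape as the hooks added in fine_grained_schedule.py
        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
            self.layer._submodule_mlp_dw()

FakeMlpNode(FakeLayer(), "mlp").dw()
```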
dcu_megatron/core/pipeline_parallel/schedules.py

```diff
@@ -7,6 +7,7 @@ from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.enums import ModelType
 from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core.pipeline_parallel.schedules import set_current_microbatch
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
 from megatron.core.utils import (
     get_attr_wrapped_model,
@@ -28,19 +29,6 @@ from megatron.core.pipeline_parallel.schedules import (
 from .combined_1f1b import VppContextManager, forward_backward_step, set_streams, wrap_forward_func


-def set_current_microbatch(model, microbatch_id):
-    """Set the current microbatch."""
-    decoder_exists = True
-    decoder = None
-    try:
-        decoder = get_attr_wrapped_model(model, "decoder")
-    except RuntimeError:
-        decoder_exists = False
-    if decoder_exists and decoder is not None:
-        for layer in decoder.layers:
-            layer.current_microbatch = microbatch_id
-
-
 def get_pp_rank_microbatches(
     num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage, forward_only=False
 ):
```
dcu_megatron/core/tensor_parallel/__init__.py (new file, mode 100644)

```python
from .mappings import all_to_all
```
(No newline at end of file.)
dcu_megatron/core/tensor_parallel/mappings.py (new file, mode 100644)

```python
import torch

from .qcomm import q_alltoall


class _AllToAll(torch.autograd.Function):
    @staticmethod
    def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_qcomm=False):
        """Forward function."""
        ctx.group = group
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes
        ctx.use_qcomm = use_qcomm

        world_size = torch.distributed.get_world_size(group=group)
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input
        input = input.contiguous()
        if output_split_sizes is None:
            # Equal split (all2all)
            if use_qcomm:
                output = input.new_empty(
                    size=[input.shape[0], input.shape[1] + 4],
                    dtype=torch.int8,
                    device=torch.cuda.current_device(),
                )
            else:
                output = torch.empty_like(input)
        else:
            # Unequal split (all2all-v)
            if use_qcomm:
                output = input.new_empty(
                    size=[sum(output_split_sizes)] + list(input.size()[1:]),
                    dtype=torch.int8,
                    device=torch.cuda.current_device(),
                )
            else:
                output = input.new_empty(
                    size=[sum(output_split_sizes)] + list(input.size()[1:]),
                    dtype=input.dtype,
                    device=torch.cuda.current_device(),
                )
        if use_qcomm:
            output = q_alltoall(output, input, output_split_sizes, input_split_sizes, group)
        else:
            torch.distributed.all_to_all_single(
                output,
                input,
                output_split_sizes=output_split_sizes,
                input_split_sizes=input_split_sizes,
                group=group,
            )
        return output

    @staticmethod
    def backward(ctx, *grad_output):
        """Backward function."""
        return (
            None,
            _AllToAll.apply(
                ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.use_qcomm
            ),
            None,
            None,
            None,
        )


def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None, use_qcomm=False):
    """Wrapper for autograd function"""
    return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_qcomm)
```
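A quick usage sketch for the new wrapper (not part of the MR; it assumes the `dcu_megatron` package is importable). With a single-rank process group, `_AllToAll` bypasses communication entirely and returns its input, so the call path, including backward, can be exercised without multiple GPUs; the quantized path (`use_qcomm=True`) only matters for `world_size > 1`.

```python
import os
import torch
import torch.distributed as dist
from dcu_megatron.core.tensor_parallel import all_to_all  # path added by this MR

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

x = torch.randn(8, 16, requires_grad=True)
y = all_to_all(dist.group.WORLD, x)   # world_size == 1: returns x unchanged
y.sum().backward()                    # backward still routes through _AllToAll.apply
print(x.grad.shape)

dist.destroy_process_group()
```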
dcu_megatron/core/tensor_parallel/qcomm.py (new file, mode 100644)

```python
import torch
import triton
import triton.language as tl
import random
import unittest
import json
import os
import time


@triton.jit
def _fwd_kernel_destindex_copy_quantize_kv_init_asym(
    K,
    Out,
    Out_scale_zero,
    stride_k_bs,
    stride_k_h,
    stride_k_d,
    stride_o_bs,
    stride_o_h,
    stride_o_d,
    stride_os_bs,
    stride_os_h,
    stride_os_d,
    head_num,
    head_dim,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
):
    cur_index = tl.program_id(0)
    offs_h = tl.arange(0, BLOCK_HEAD)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    dest_index = cur_index

    m1 = offs_h[:, None] < head_num
    m2 = offs_d[None, :] < head_dim
    mask = m1 & m2

    src_data = tl.load(
        K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],
        mask=mask,
        other=0.0,
    ).to(tl.float32)
    src_data_max = tl.max(src_data, axis=1, keep_dims=True)
    src_data_min = tl.min(src_data, axis=1, keep_dims=True)

    data_scale = (src_data_max - src_data_min) / 255.0
    data_zero = (-1 * src_data_min / data_scale).to(tl.int32)
    q_src_data = (
        tl.clamp(
            (src_data / data_scale).to(tl.int32).to(tl.float32) + data_zero.to(tl.float32),
            0.0,
            255.0,
        ).to(tl.int32)
        - 128
    ).to(tl.int8)

    data_scale = data_scale.to(Out_scale_zero.dtype.element_ty)
    data_zero = data_zero.to(Out_scale_zero.dtype.element_ty)

    o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
    os_ptrs = Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]
    oz_ptrs = Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] + 1

    tl.store(o_ptrs, q_src_data, mask=mask)
    tl.store(os_ptrs, data_scale, mask=m1)
    tl.store(oz_ptrs, data_zero, mask=m1)


@torch.no_grad()
def destindex_copy_quantize_kv_init_asym(K, Out, Out_scale_zero):
    bs_seq = K.shape[0]
    head_num = K.shape[1]
    head_dim = K.shape[2]
    assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]
    BLOCK_HEAD = triton.next_power_of_2(head_num)
    BLOCK_DMODEL = triton.next_power_of_2(head_dim)
    grid = (bs_seq,)
    num_warps = 1

    _fwd_kernel_destindex_copy_quantize_kv_init_asym[grid](
        K,
        Out,
        Out_scale_zero,
        K.stride(0),
        K.stride(1),
        K.stride(2),
        Out.stride(0),
        Out.stride(1),
        Out.stride(2),
        Out_scale_zero.stride(0),
        Out_scale_zero.stride(1),
        Out_scale_zero.stride(2),
        head_num,
        head_dim,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_HEAD=BLOCK_HEAD,
        num_warps=num_warps,
        num_stages=1,
    )
    return


@triton.jit
def _bwd_kernel_destindex_dequantize_kv(
    Quantized_Out,
    Out_scale_zero,
    Dequantized_Out,
    stride_qo_bs,
    stride_qo_h,
    stride_qo_d,
    stride_os_bs,
    stride_os_h,
    stride_os_d,
    stride_do_bs,
    stride_do_h,
    stride_do_d,
    head_num,
    head_dim,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
):
    cur_index = tl.program_id(0)
    offs_h = tl.arange(0, BLOCK_HEAD)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    scales_dtype = Out_scale_zero.dtype.element_ty

    dest_index = cur_index

    m1 = offs_h[:, None] < head_num
    m2 = offs_d[None, :] < head_dim
    mask = m1 & m2

    # Load quantized data
    q_data = tl.load(
        Quantized_Out + dest_index * stride_qo_bs + offs_h[:, None] * stride_qo_h + stride_qo_d * offs_d[None, :],
        mask=mask,
        other=0,
    )

    # Load scale and zero point
    data_scale = tl.load(
        Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None],
        mask=m1,
        other=1.0,
    )
    data_zero = tl.load(
        Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] + 1,
        mask=m1,
        other=0,
    )

    # Dequantize
    dequantized_data = (q_data.to(tl.int32) + 128 - data_zero.to(tl.int32)).to(scales_dtype) * data_scale

    # Store dequantized data
    out_ptrs = Dequantized_Out + dest_index * stride_do_bs + stride_do_h * offs_h[:, None] + stride_do_d * offs_d[None, :]
    tl.store(out_ptrs, dequantized_data, mask=mask)


@torch.no_grad()
def destindex_dequantize_kv(Quantized_Out, Out_scale_zero, Dequantized_Out):
    bs_seq = Quantized_Out.shape[0]
    head_num = Quantized_Out.shape[1]
    head_dim = Quantized_Out.shape[2]
    assert (
        Quantized_Out.shape[1] == Dequantized_Out.shape[1]
        and Quantized_Out.shape[2] == Dequantized_Out.shape[2]
    )
    BLOCK_HEAD = triton.next_power_of_2(head_num)
    BLOCK_DMODEL = triton.next_power_of_2(head_dim)
    grid = (bs_seq,)
    num_warps = 1

    _bwd_kernel_destindex_dequantize_kv[grid](
        Quantized_Out,
        Out_scale_zero,
        Dequantized_Out,
        Quantized_Out.stride(0),
        Quantized_Out.stride(1),
        Quantized_Out.stride(2),
        Out_scale_zero.stride(0),
        Out_scale_zero.stride(1),
        Out_scale_zero.stride(2),
        Dequantized_Out.stride(0),
        Dequantized_Out.stride(1),
        Dequantized_Out.stride(2),
        head_num,
        head_dim,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_HEAD=BLOCK_HEAD,
        num_warps=num_warps,
        num_stages=1,
    )


@torch.no_grad()
def fp16_to_int8s(fp16_tensor):
    fp16_bytes = fp16_tensor.contiguous().view(torch.int8)
    int8_high = fp16_bytes[::2]   # high 8 bits
    int8_low = fp16_bytes[1::2]   # low 8 bits
    return int8_high.unsqueeze(1), int8_low.unsqueeze(1)


@torch.no_grad()
def int8s_to_fp16(int8_high, int8_low):
    fp16_bytes = torch.stack([int8_high, int8_low], dim=-1).view(torch.int16)
    return fp16_bytes.view(torch.bfloat16)


def _alltoall(group, input, output_split_sizes, input_split_sizes):
    input = input.contiguous()
    if output_split_sizes is None:
        # Equal split (all2all)
        output = torch.empty_like(input)
    else:
        # Unequal split (all2all-v)
        output = input.new_empty(
            size=[sum(output_split_sizes)] + list(input.size()[1:]),
            dtype=input.dtype,
            device=torch.cuda.current_device(),
        )
    torch.distributed.all_to_all_single(
        output,
        input,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )
    return output


def q_alltoall(output, input, output_split_sizes, input_split_sizes, group):
    t, s = input.shape[0], input.shape[1]
    input_buffer_int8 = torch.empty((t, 1, s), dtype=torch.int8, device="cuda")
    buffer_scales = torch.empty((t, 1, 2), dtype=torch.bfloat16, device="cuda")
    input_q = input.unsqueeze(1)
    destindex_copy_quantize_kv_init_asym(
        input_q,
        input_buffer_int8,
        buffer_scales,
    )
    input_buffer_int8 = input_buffer_int8.squeeze()
    buffer_scales = buffer_scales.squeeze()

    buffer_scales_h, buffer_scales_l = fp16_to_int8s(buffer_scales[:, 0])
    buffer_shift_h, buffer_shift_l = fp16_to_int8s(buffer_scales[:, 1])
    input_all = torch.cat(
        [input_buffer_int8, buffer_scales_h, buffer_scales_l, buffer_shift_h, buffer_shift_l], dim=1
    )

    torch.distributed.all_to_all_single(
        output,
        input_all,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )

    scale = int8s_to_fp16(output[:, -4], output[:, -3])
    shift = int8s_to_fp16(output[:, -2], output[:, -1])
    scales = torch.cat([scale, shift], dim=1).unsqueeze(1)
    deq_out = torch.empty(
        (output.shape[0], 1, output.shape[1] - 4), dtype=torch.bfloat16, device="cuda"
    )
    destindex_dequantize_kv(output[:, :-4].unsqueeze(1), scales, deq_out)
    return deq_out.squeeze()
```
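A hedged round-trip sketch for the quantization helpers above (not part of the MR; it needs a GPU with Triton and assumes the module is importable at the path added by this MR). It quantizes a bf16 tensor to int8 with a per-token scale/zero pair, dequantizes, and prints the worst-case error, mirroring the `(payload, scale_hi, scale_lo, zero_hi, zero_lo)` packing that `q_alltoall` sends over the wire.

```python
import torch
from dcu_megatron.core.tensor_parallel.qcomm import (
    destindex_copy_quantize_kv_init_asym,
    destindex_dequantize_kv,
)

# shapes follow q_alltoall: (tokens, 1, hidden) payload and (tokens, 1, 2) scale/zero
x = torch.randn(32, 1, 128, dtype=torch.bfloat16, device="cuda")
q = torch.empty_like(x, dtype=torch.int8)
scale_zero = torch.empty(32, 1, 2, dtype=torch.bfloat16, device="cuda")

destindex_copy_quantize_kv_init_asym(x, q, scale_zero)

deq = torch.empty_like(x)
destindex_dequantize_kv(q, scale_zero, deq)

# 8-bit asymmetric quantization: expect a small per-element error, on the
# order of (max - min) / 255 per token
print((deq.float() - x.float()).abs().max())
```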
dcu_megatron/core/transformer/moe/token_dispatcher.py

```diff
@@ -3,8 +3,8 @@ from typing import Optional, Tuple

 import torch

+from megatron.training import get_args
 from megatron.core.tensor_parallel import (
-    all_to_all,
     gather_from_sequence_parallel_region,
     reduce_scatter_to_sequence_parallel_region,
 )
@@ -15,6 +15,8 @@ from megatron.core.transformer.moe.moe_utils import (
 )
 from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher

+from dcu_megatron.core.tensor_parallel import all_to_all
+

 # decouple perbatch state from MoEAlltoAllTokenDispatcher
 class MoEAlltoAllPerBatchState:
@@ -35,6 +37,13 @@ class MoEAlltoAllPerBatchState:
 class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # use_qcomm
+        args = get_args()
+        self.use_qcomm = args.use_qcomm
+
     def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
         state.num_global_tokens_per_local_expert = getattr(self, "num_global_tokens_per_local_expert", None)
@@ -125,7 +134,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
                 "before_ep_alltoall", tokens_per_expert
             )
         global_input_tokens = all_to_all(
-            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
+            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits, use_qcomm=self.use_qcomm
         )
         return tokens_per_expert, global_input_tokens
@@ -249,7 +258,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         # Perform expert parallel AlltoAll communication
         # hidden_states: [SEQL, H] -> [SEQL, H/TP]
         permutated_local_input_tokens = all_to_all(
-            self.ep_group, hidden_states, self.input_splits, self.output_splits
+            self.ep_group, hidden_states, self.input_splits, self.output_splits, use_qcomm=self.use_qcomm
         )
         return permutated_local_input_tokens
```
dcu_megatron/core/transformer/transformer_layer.py

```diff
@@ -10,41 +10,13 @@ from megatron.core.utils import (
     deprecate_inference_params,
     make_viewless_tensor,
 )
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer

 from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables


 class TransformerLayer(MegatronCoreTransformerLayer):

-    def _callable_wrapper(self, is_forward, func, stream, event, *args, skip_detach=False, **kwargs):
-        """
-        Wraps a function call so that it waits for a given CUDA event before
-        proceeding and then runs the function on a specified CUDA stream.
-        """
-        torch.cuda.nvtx.range_push(func.__name__)
-        event.wait(stream)
-        with torch.cuda.stream(stream):
-            outputs = func(*args, **kwargs)
-        event.record(stream)
-
-        if skip_detach:
-            torch.cuda.nvtx.range_pop()
-            return outputs
-
-        detached_output_tensors = []
-        if not is_forward:
-            torch.cuda.nvtx.range_pop()
-            return outputs, detached_output_tensors
-
-        for tensor in outputs:
-            if tensor is None:
-                detached_output_tensors.append(None)
-            elif tensor.dtype.is_floating_point:
-                detached_output_tensors.append(tensor.detach().requires_grad_(True))
-            else:
-                detached_output_tensors.append(tensor.detach())
-        torch.cuda.nvtx.range_pop()
-        return outputs, detached_output_tensors
-
     def forward(
         self,
         hidden_states: Tensor,
@@ -61,6 +33,23 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         *,
         inference_params: Optional[Any] = None,
     ):
+        if not isinstance(self.mlp, MoELayer):
+            return super().forward(
+                hidden_states=hidden_states,
+                context=context,
+                context_mask=context_mask,
+                attention_mask=attention_mask,
+                rotary_pos_emb=rotary_pos_emb,
+                rotary_pos_cos=rotary_pos_cos,
+                rotary_pos_sin=rotary_pos_sin,
+                attention_bias=attention_bias,
+                inference_context=inference_context,
+                packed_seq_params=packed_seq_params,
+                sequence_len_offset=sequence_len_offset,
+                inference_params=inference_params,
+            )

         (
             hidden_states,
             pre_mlp_layernorm_output,
@@ -123,7 +112,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         residual = hidden_states

         # Optional Input Layer norm
-        input_layernorm_output = self.input_layernorm(hidden_states)
+        if self.recompute_input_layernorm:
+            self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
+            input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
+                self.input_layernorm, hidden_states
+            )
+        else:
+            input_layernorm_output = self.input_layernorm(hidden_states)

         # Self attention.
         attention_output_with_bias = self.self_attention(
@@ -138,6 +133,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             sequence_len_offset=sequence_len_offset,
         )

+        if self.recompute_input_layernorm:
+            # discard the output of the input layernorm and register the recompute
+            # as a gradient hook of attention_output_with_bias[0]
+            self.input_layernorm_checkpoint.discard_output_and_register_recompute(
+                attention_output_with_bias[0]
+            )
+
         # TODO: could we move `bias_dropout_add_exec_handler` itself
         # inside the module provided in the `bias_dropout_add_spec` module?
         with self.bias_dropout_add_exec_handler():
@@ -178,7 +180,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         )

         # Optional Layer norm post the cross-attention.
-        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
+        if self.recompute_pre_mlp_layernorm:
+            self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
+            pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
+                self.pre_mlp_layernorm, hidden_states
+            )
+        else:
+            pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

         probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
         tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
@@ -249,6 +257,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         if shared_expert_output is not None:
             output += shared_expert_output
         mlp_output_with_bias = (output, mlp_bias)

+        if self.recompute_pre_mlp_layernorm:
+            # discard the output of the pre-mlp layernorm and register the recompute
+            # as a gradient hook of mlp_output_with_bias[0]
+            self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
+                mlp_output_with_bias[0]
+            )
+
         # TODO: could we move `bias_dropout_add_exec_handler` itself
         # inside the module provided in the `bias_dropout_add_spec` module?
         with self.bias_dropout_add_exec_handler():
             hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                 mlp_output_with_bias, residual, self.hidden_dropout
@@ -259,10 +277,11 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         return output

-    def _submodule_attention_router_compound_dw(self):
+    def _submodule_attention_dw(self):
         self.self_attention.backward_dw()
         # raise NotImplementedError("Not implemented")

+    def _submodule_attention_router_compound_dw(self):
+        self._submodule_attention_dw()
+
     def _submodule_mlp_dw(self):
         self.mlp.backward_dw()
         # raise NotImplementedError("Not implemented")
```
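The `recompute_input_layernorm` / `recompute_pre_mlp_layernorm` branches above drop the layernorm output and recompute it during backward via a gradient hook. As a generic illustration of the same memory/compute trade (this uses plain `torch.utils.checkpoint`, not Megatron's `CheckpointWithoutOutput` API, and assumes a recent PyTorch with `use_reentrant`):

```python
import torch
from torch.utils.checkpoint import checkpoint

norm = torch.nn.LayerNorm(64)
linear = torch.nn.Linear(64, 64)

x = torch.randn(8, 64, requires_grad=True)
# the normalized activation is not stored; it is recomputed during backward
normed = checkpoint(norm, x, use_reentrant=False)
out = linear(normed).sum()
out.backward()
print(x.grad.norm())
```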
dcu_megatron/training/arguments.py

```diff
@@ -23,6 +23,7 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
     # add extra arguments
     parser = _add_extra_network_size_args(parser)
     parser = _add_extra_training_args(parser)
+    parser = _add_extra_initialization_args(parser)
     parser = _add_extra_distributed_args(parser)
     parser = _add_extra_tokenizer_args(parser)
     parser = _add_extra_moe_args(parser)
@@ -96,6 +97,14 @@ def _add_extra_training_args(parser):
     return parser


+def _add_extra_initialization_args(parser):
+    group = parser.add_argument_group(title='extra initialization args')
+
+    group.add_argument('--reproduce', action='store_true',
+                       help='reproduce train loss, need set --seed > 0.')
+    return parser
+
+
 def _add_extra_tokenizer_args(parser):
     # remove the original parameters
     remove_original_params(parser, ["tokenizer_type"])
@@ -120,6 +129,10 @@ def _add_extra_tokenizer_args(parser):
                                'NullTokenizer',
                                'DeepSeekV2Tokenizer'],
                        help='What type of tokenizer to use.')
+    group.add_argument('--use-qcomm', default=False, action="store_true",
+                       help='use quantized communication')
+
     return parser
```
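A standalone sketch of how the two new flags parse (minimal argparse only, not the Megatron parser): both are `store_true` switches that default to `False` and flip to `True` only when passed on the command line.

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='extra initialization args')
group.add_argument('--reproduce', action='store_true',
                   help='reproduce train loss, need set --seed > 0.')
group.add_argument('--use-qcomm', default=False, action='store_true',
                   help='use quantized communication')

args = parser.parse_args(['--use-qcomm'])
print(args.reproduce, args.use_qcomm)   # False True
```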
dcu_megatron/training/initialize.py

```diff
 """Megatron initialization."""

 import random
 import time

 import numpy as np
 import torch
 from datetime import timedelta

 from megatron.training import get_args
-from megatron.core import mpu
+from megatron.core import mpu, tensor_parallel


 def _compile_dependencies():
@@ -105,7 +108,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
     # Call the init process
     init_process_group_kwargs = {
         'backend': args.distributed_backend,
         'world_size': args.world_size,
         'rank': args.rank,
         'init_method': args.dist_url,
@@ -149,3 +152,35 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
             f"> initialized pipeline model parallel with size "
             f"{mpu.get_pipeline_model_parallel_world_size()}"
         )
+
+
+def _set_random_seed(
+    seed_: int,
+    data_parallel_random_init: bool = False,
+    te_rng_tracker: bool = False,
+    inference_rng_tracker: bool = False,
+    use_cudagraphable_rng: bool = False,
+):
+    """Set random seed for reproducability."""
+    args = get_args()
+    if seed_ is not None and seed_ > 0:
+        # Ensure that different pipeline MP stages get different seeds.
+        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
+        # Ensure different data parallel ranks get different seeds
+        if data_parallel_random_init:
+            seed = seed + (10 * mpu.get_data_parallel_rank())
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        if torch.cuda.device_count() > 0:
+            tensor_parallel.model_parallel_cuda_manual_seed(
+                seed, te_rng_tracker, inference_rng_tracker, use_cudagraphable_rng
+            )
+        if args.reproduce:
+            assert (args.attention_dropout > 0) is False, \
+                f"To utilize the reproduction function, args.attention_dropout = {args.attention_dropout} must be set to 0."
+            assert (args.hidden_dropout > 0) is False, \
+                f"To utilize the reproduction function, args.hidden_dropout = {args.hidden_dropout} must be set to 0."
+            torch.backends.cudnn.deterministic = True   # use deterministic cuDNN algorithms
+            torch.backends.cudnn.benchmark = False      # pin the convolution algorithm selection
+            torch.use_deterministic_algorithms(True)    # use deterministic torch operators to avoid nondeterminism
+    else:
+        raise ValueError("Seed ({}) should be a positive integer.".format(seed_))
```
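A minimal sketch of the seed layout that `_set_random_seed` applies, written here as a pure function without the `mpu` rank lookups (an assumption for illustration): each pipeline stage is offset by 100, and each data-parallel rank by a further 10 when `data_parallel_random_init` is set, so ranks draw from disjoint seeds.

```python
def derive_seed(base_seed: int, pp_rank: int, dp_rank: int = 0,
                data_parallel_random_init: bool = False) -> int:
    # mirrors the offsets in _set_random_seed above
    seed = base_seed + 100 * pp_rank
    if data_parallel_random_init:
        seed += 10 * dp_rank
    return seed

print(derive_seed(1234, pp_rank=2))                                              # 1434
print(derive_seed(1234, pp_rank=2, dp_rank=3, data_parallel_random_init=True))   # 1464
```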
dcu_megatron/training/tokenizer/tokenizer.py

```diff
@@ -9,8 +9,10 @@ from megatron.training.tokenizer.tokenizer import (
     _Llama2Tokenizer,
     CustomTikTokenizer,
     _NullTokenizer,
+    _NullMultimodalTokenizer,
     _vocab_size_with_padding
 )
 from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer


 def build_tokenizer(args, **kwargs):
@@ -92,7 +94,11 @@ def build_tokenizer(args, **kwargs):
             args.tokenizer_prompt_format,
             args.special_tokens,
             args.image_tag_type,
             args.force_system_message,
         )
+    elif args.tokenizer_type == 'NullMultimodalTokenizer':
+        assert args.vocab_size is not None
+        tokenizer = _NullMultimodalTokenizer(args.vocab_size)
     elif args.tokenizer_type == "DeepSeekV2Tokenizer":
         tokenizer = _DeepSeekV2Tokenizer(args.tokenizer_model, args.extra_vocab_size)
         args.padded_vocab_size = tokenizer.vocab_size
```
dcu_megatron/training/training.py

```diff
@@ -53,18 +53,9 @@ from megatron.training.training import (
 stimer = StragglerDetector()

-def train(
-    forward_step_func,
-    model,
-    optimizer,
-    opt_param_scheduler,
-    train_data_iterator,
-    valid_data_iterator,
-    process_non_loss_data_func,
-    config,
-    checkpointing_context,
-    non_loss_data_func,
-):
+def train(forward_step_func, model, optimizer, opt_param_scheduler,
+          train_data_iterator, valid_data_iterator, process_non_loss_data_func,
+          config, checkpointing_context, non_loss_data_func):
     """Training function: run train_step desired number of times, run validation, checkpoint."""
     args = get_args()
     timers = get_timers()
@@ -74,10 +65,7 @@ def train(
     try:
         from workload_inspector.utils.webserver import run_server
         import threading
-        threading.Thread(
-            target=run_server, daemon=True, args=(torch.distributed.get_rank(),)
-        ).start()
+        threading.Thread(target=run_server, daemon=True, args=(torch.distributed.get_rank(), )).start()
     except ModuleNotFoundError:
         print_rank_0("workload inspector module not found.")
@@ -100,17 +88,11 @@ def train(
     rerun_state_machine.current_iteration = iteration

     # Track E2E metrics at the start of training.
-    one_logger_utils.on_train_start(
-        iteration=iteration,
-        consumed_train_samples=args.consumed_train_samples,
-        train_samples=args.train_samples,
-        seq_length=args.seq_length,
-        train_iters=args.train_iters,
-        save=args.save,
-        async_save=args.async_save,
-        log_throughput=args.log_throughput,
-        num_floating_point_operations_so_far=args.num_floating_point_operations_so_far,
-    )
+    one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
+                                    train_samples=args.train_samples, seq_length=args.seq_length,
+                                    train_iters=args.train_iters, save=args.save, async_save=args.async_save,
+                                    log_throughput=args.log_throughput,
+                                    num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)

     num_floating_point_operations_so_far = args.num_floating_point_operations_so_far
@@ -118,10 +100,9 @@ def train(
     config.grad_scale_func = optimizer.scale_loss
     config.timers = timers
     if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
-        assert config.no_sync_func is None, (
-            'When overlap_grad_reduce is True, config.no_sync_func must be None; '
-            'a custom no_sync_func is not supported when overlapping grad-reduce'
-        )
+        assert config.no_sync_func is None, \
+            ('When overlap_grad_reduce is True, config.no_sync_func must be None; '
+             'a custom no_sync_func is not supported when overlapping grad-reduce')
         config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
         if len(model) == 1:
             config.no_sync_func = config.no_sync_func[0]
@@ -145,9 +126,8 @@ def train(
     if args.manual_gc:
         # Disable the default garbage collector and perform the collection manually.
         # This is to align the timing of garbage collection across ranks.
-        assert (
-            args.manual_gc_interval >= 0
-        ), 'Manual garbage collection interval should be larger than or equal to 0'
+        assert args.manual_gc_interval >= 0, \
+            'Manual garbage collection interval should be larger than or equal to 0'
         gc.disable()
         gc.collect()
@@ -157,13 +137,10 @@ def train(
         world = torch.distributed.get_world_size()
         rank = torch.distributed.get_rank()
         mmcnt = args.straggler_minmax_count
-        stimer.configure(
-            world,
-            rank,
-            mmcnt=mmcnt,
-            enabled=not args.disable_straggler_on_startup,
-            port=args.straggler_ctrlr_port,
-        )
+        stimer.configure(world, rank, mmcnt=mmcnt,
+                         enabled=not args.disable_straggler_on_startup,
+                         port=args.straggler_ctrlr_port)
     num_floating_point_operations_since_last_log_event = 0.0

     num_microbatches = get_num_microbatches()
@@ -171,10 +148,10 @@ def train(
     eval_iterations = 0

     def get_e2e_base_metrics():
-        """Get base metrics values for one-logger to calculate E2E tracking metrics."""
-        num_floating_point_operations_since_current_train_start = (
-            num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
-        )
+        """Get base metrics values for one-logger to calculate E2E tracking metrics.
+        """
+        num_floating_point_operations_since_current_train_start = \
+            num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
         return {
             'iteration': iteration,
             'train_duration': timers('interval-time').active_time(),
@@ -184,7 +161,7 @@ def train(
             'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
             'consumed_train_samples': args.consumed_train_samples,
             'world_size': args.world_size,
-            'seq_length': args.seq_length,
+            'seq_length': args.seq_length
         }

     # Cache into one-logger for callback.
     if one_logger:
@@ -192,11 +169,7 @@ def train(
         one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)

     prof = None
-    if (
-        args.profile
-        and torch.distributed.get_rank() in args.profile_ranks
-        and args.use_pytorch_profiler
-    ):
+    if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
         def trace_handler(p):
             from pathlib import Path
             Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
@@ -242,9 +215,8 @@ def train(
         pre_hook_enabled = False
         # Also, check weight hash across DP replicas to be very pedantic.
         if args.check_weight_hash_across_dp_replicas_interval is not None:
-            assert check_param_hashes_across_dp_replicas(
-                model, cross_check=True
-            ), "Parameter hashes not matching across DP replicas"
+            assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
+                "Parameter hashes not matching across DP replicas"
             torch.distributed.barrier()
             print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
@@ -270,20 +242,14 @@ def train(
         # to make sure training configuration is still valid.
         update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
         if get_num_microbatches() != num_microbatches and iteration != 0:
-            assert get_num_microbatches() > num_microbatches, (
-                f"Number of microbatches should be increasing due to batch size rampup; "
-                f"instead going from {num_microbatches} to {get_num_microbatches()}"
-            )
+            assert get_num_microbatches() > num_microbatches, \
+                (f"Number of microbatches should be increasing due to batch size rampup; "
+                 f"instead going from {num_microbatches} to {get_num_microbatches()}")
             if args.save is not None:
-                save_checkpoint_and_time(
-                    iteration,
-                    model,
-                    optimizer,
-                    opt_param_scheduler,
-                    num_floating_point_operations_so_far,
-                    checkpointing_context,
-                    train_data_iterator=train_data_iterator,
-                )
+                save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
+                                         num_floating_point_operations_so_far, checkpointing_context,
+                                         train_data_iterator=train_data_iterator)
         num_microbatches = get_num_microbatches()
         update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
@@ -292,9 +258,9 @@ def train(
             # Dummy train_step to fast forward train_data_iterator.
             dummy_train_step(train_data_iterator)
             iteration += 1
-            batch_size = (
-                mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
-            )
+            batch_size = mpu.get_data_parallel_world_size() * \
+                         args.micro_batch_size * \
+                         get_num_microbatches()
             args.consumed_train_samples += batch_size
             args.skipped_train_samples += batch_size
             continue
@@ -302,28 +268,19 @@ def train(
         # Run training step.
         args.curr_iteration = iteration
         ft_integration.on_training_step_start()
-        (
-            loss_dict,
-            skipped_iter,
-            should_checkpoint,
-            should_exit,
-            exit_code,
-            grad_norm,
-            num_zeros_in_grad,
-        ) = train_step(
-            forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config
-        )
+        loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
+            train_step(forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config)
         ft_integration.on_training_step_end()
         if should_checkpoint:
-            save_checkpoint_and_time(
-                iteration,
-                model,
-                optimizer,
-                opt_param_scheduler,
-                num_floating_point_operations_so_far,
-                checkpointing_context,
-                train_data_iterator=train_data_iterator,
-            )
+            save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
+                                     num_floating_point_operations_so_far, checkpointing_context,
+                                     train_data_iterator=train_data_iterator)
         if should_exit:
             break
@@ -346,13 +303,12 @@ def train(
             pre_hook_enabled = True

         iteration += 1
-        batch_size = (
-            mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
-        )
+        batch_size = mpu.get_data_parallel_world_size() * \
+                     args.micro_batch_size * \
+                     get_num_microbatches()
         args.consumed_train_samples += batch_size
-        num_skipped_samples_in_batch = (
-            get_current_global_batch_size() - get_current_running_global_batch_size()
-        )
+        num_skipped_samples_in_batch = (get_current_global_batch_size() -
+                                        get_current_running_global_batch_size())
         if args.decrease_batch_size_if_needed:
             assert num_skipped_samples_in_batch >= 0
         else:
@@ -378,22 +334,16 @@ def train(
                 decoupled_learning_rate = param_group['lr']
             else:
                 learning_rate = param_group['lr']
-        report_memory_flag = training_log(
-            loss_dict,
-            total_loss_dict,
-            learning_rate,
-            decoupled_learning_rate,
-            iteration,
-            loss_scale,
-            report_memory_flag,
-            skipped_iter,
-            grad_norm,
-            params_norm,
-            num_zeros_in_grad,
-        )
+        report_memory_flag = training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate,
+                                          iteration, loss_scale, report_memory_flag, skipped_iter,
+                                          grad_norm, params_norm, num_zeros_in_grad)

         # Evaluation.
-        if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
+        if args.eval_interval and iteration % args.eval_interval == 0 and \
+           args.do_valid:
             timers('interval-time').stop()
             if should_disable_forward_pre_hook(args):
                 disable_forward_pre_hook(model)
@@ -403,18 +353,11 @@ def train(
                 gc.collect()
             prefix = f'iteration {iteration}'
             timers('eval-time', log_level=0).start(barrier=True)
-            evaluate_and_print_results(
-                prefix,
-                forward_step_func,
-                valid_data_iterator,
-                model,
-                iteration,
-                process_non_loss_data_func,
-                config,
-                verbose=False,
-                write_to_tensorboard=True,
-                non_loss_data_func=non_loss_data_func,
-            )
+            evaluate_and_print_results(prefix, forward_step_func,
+                                       valid_data_iterator, model,
+                                       iteration, process_non_loss_data_func,
+                                       config, verbose=False, write_to_tensorboard=True,
+                                       non_loss_data_func=non_loss_data_func)
             eval_duration += timers('eval-time').elapsed()
             eval_iterations += args.eval_iters
             timers('eval-time').stop()
@@ -430,25 +373,13 @@ def train(
         # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
         # Some of these only happen at specific iterations.
-        post_training_step_callbacks(
-            model,
-            optimizer,
-            opt_param_scheduler,
-            iteration,
-            prof,
-            num_floating_point_operations_since_last_log_event,
-        )
+        post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
+                                     num_floating_point_operations_since_last_log_event)

         # Checkpoint and decide whether to exit.
-        should_exit = checkpoint_and_decide_exit(
-            model,
-            optimizer,
-            opt_param_scheduler,
-            iteration,
-            num_floating_point_operations_so_far,
-            checkpointing_context,
-            train_data_iterator,
-        )
+        should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
+                                                 num_floating_point_operations_so_far, checkpointing_context,
+                                                 train_data_iterator)
         if should_exit:
             break
@@ -477,7 +408,6 @@ def train(
         if wandb_writer:
             wandb_writer.finish()
         ft_integration.shutdown()
         one_logger_utils.finish()
         sys.exit(exit_code)

     return iteration, num_floating_point_operations_so_far
```
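The profiling branch reformatted above ultimately hands a `trace_handler` to `torch.profiler`. A sketch of that pattern (not from this MR; `profile_dir` is a made-up stand-in for `args.profile_dir`): the handler receives the profiler object when a scheduled window completes, creates the output directory, and writes a Chrome trace.

```python
import torch
from pathlib import Path
from torch.profiler import profile, schedule, ProfilerActivity

profile_dir = "./prof"

def trace_handler(p):
    # mirror the Path(...).mkdir(parents=True, exist_ok=True) call in the diff
    Path(profile_dir).mkdir(parents=True, exist_ok=True)
    p.export_chrome_trace(f"{profile_dir}/trace_step{p.step_num}.json")

with profile(activities=[ProfilerActivity.CPU],
             schedule=schedule(wait=1, warmup=1, active=2),
             on_trace_ready=trace_handler) as prof:
    for _ in range(8):
        torch.randn(256, 256) @ torch.randn(256, 256)
        prof.step()
```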