evt_fugx1 / dcu_megatron / Commits / c63fceee

Commit c63fceee, authored May 19, 2025 by dongcl

Merge branch 'a2a_overlap' into 'core_v0.12.0'

A2a overlap

See merge request OpenDAS/dcu_megatron!4

Parents: 6c3cfb1d, bfe0b4a9
Changes: 12 changed files with 504 additions and 191 deletions (+504 −191)
dcu_megatron/adaptor/megatron_adaptor.py                  +14    −0
dcu_megatron/core/models/gpt/fine_grained_schedule.py      +9    −0
dcu_megatron/core/pipeline_parallel/schedules.py           +1   −13
dcu_megatron/core/tensor_parallel/__init__.py              +1    −0
dcu_megatron/core/tensor_parallel/mappings.py             +72    −0
dcu_megatron/core/tensor_parallel/qcomm.py               +217    −0
dcu_megatron/core/transformer/moe/token_dispatcher.py     +12    −3
dcu_megatron/core/transformer/transformer_layer.py        +53   −34
dcu_megatron/training/arguments.py                        +13    −0
dcu_megatron/training/initialize.py                       +37    −2
dcu_megatron/training/tokenizer/tokenizer.py               +6    −0
dcu_megatron/training/training.py                         +69  −139
dcu_megatron/adaptor/megatron_adaptor.py

@@ -169,6 +169,15 @@ class CoreAdaptation(MegatronAdaptationABC):

                                    staticmethod,
                                    apply_wrapper=True)
        # reduce_scatter_to_sequence_parallel_region
        MegatronAdaptation.register(
            'megatron.core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region',
            torch._dynamo.disable,
            apply_wrapper=True)
        # reduce_from_tensor_model_parallel_region
        MegatronAdaptation.register(
            'megatron.core.tensor_parallel.mappings.reduce_from_tensor_model_parallel_region',
            torch._dynamo.disable,
            apply_wrapper=True)

        # flux
        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
            from ..core.tensor_parallel.layers import (

@@ -189,6 +198,7 @@ class CoreAdaptation(MegatronAdaptationABC):

        from ..training.initialize import _initialize_distributed
        from ..training.initialize import _compile_dependencies
        from ..training.training import train
        from ..training.initialize import _set_random_seed

        MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
                                    build_tokenizer)

@@ -199,6 +209,10 @@ class CoreAdaptation(MegatronAdaptationABC):

        MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
                                    _compile_dependencies)
        # add a fixed seed
        MegatronAdaptation.register('megatron.training.initialize._set_random_seed',
                                    _set_random_seed)
        # add trace_handler
        MegatronAdaptation.register('megatron.training.training.train',
                                    train)
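For orientation: the adaptor registers replacement callables against dotted import paths in upstream Megatron-LM, optionally treating the replacement as a wrapper around the original (as with torch._dynamo.disable above). The diff does not show MegatronAdaptation.register itself; the snippet below is only a minimal hypothetical sketch of that style of registry-based patching, not the project's implementation.

import importlib

def register(dotted_path: str, replacement, apply_wrapper: bool = False):
    # Hypothetical helper: replace the attribute named by `dotted_path`.
    # If apply_wrapper is True, `replacement` receives the original callable
    # and returns the patched one (e.g. torch._dynamo.disable(original)).
    module_path, attr_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    original = getattr(module, attr_name)
    patched = replacement(original) if apply_wrapper else replacement
    setattr(module, attr_name, patched)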
dcu_megatron/core/models/gpt/fine_grained_schedule.py

@@ -397,6 +397,10 @@ class DenseAttnNode(TransformerLayerNode):

        )
        return hidden_states

    def dw(self):
        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
            self.layer._submodule_attention_dw()


class FakeScheduleNode:

@@ -411,6 +415,10 @@ class DenseMlpNode(TransformerLayerNode):

    def forward_impl(self, hidden_states):
        return self.layer._submodule_dense_forward(hidden_states)

    def dw(self):
        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
            self.layer._submodule_mlp_dw()


def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream):
    common_state = TransformerLayerState()

@@ -418,6 +426,7 @@ def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream):

    attn.name = "attn"
    dispatch = FakeScheduleNode()
    mlp = DenseMlpNode(chunk_state, common_state, layer, comp_stream, event)
    mlp.name = "mlp"
    combine = FakeScheduleNode()
    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
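The new dw() hooks wrap the weight-gradient step in torch.cuda.nvtx.range so the phase shows up as a named range in profiler timelines. A minimal self-contained illustration of that pattern (the shapes and the do_wgrad body are made up for the example):

import torch

def do_wgrad(weight, grad_output, activation):
    # Placeholder weight-gradient computation for the example.
    weight.grad = grad_output.t() @ activation

if torch.cuda.is_available():
    w = torch.randn(16, 16, device="cuda", requires_grad=True)
    gout = torch.randn(8, 16, device="cuda")
    act = torch.randn(8, 16, device="cuda")
    # The named range makes the wgrad phase visible in the profiler timeline.
    with torch.cuda.nvtx.range("mlp wgrad"):
        do_wgrad(w, gout, act)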
dcu_megatron/core/pipeline_parallel/schedules.py

@@ -7,6 +7,7 @@ from megatron.training import get_args

from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.pipeline_parallel.schedules import set_current_microbatch
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.utils import (
    get_attr_wrapped_model,

@@ -28,19 +29,6 @@ from megatron.core.pipeline_parallel.schedules import (

from .combined_1f1b import VppContextManager, forward_backward_step, set_streams, wrap_forward_func


def set_current_microbatch(model, microbatch_id):
    """Set the current microbatch."""
    decoder_exists = True
    decoder = None
    try:
        decoder = get_attr_wrapped_model(model, "decoder")
    except RuntimeError:
        decoder_exists = False
    if decoder_exists and decoder is not None:
        for layer in decoder.layers:
            layer.current_microbatch = microbatch_id


def get_pp_rank_microbatches(
    num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage, forward_only=False
):
dcu_megatron/core/tensor_parallel/__init__.py (new file, mode 100644)

from .mappings import all_to_all
\ No newline at end of file
dcu_megatron/core/tensor_parallel/mappings.py (new file, mode 100644)

import torch

from .qcomm import q_alltoall


class _AllToAll(torch.autograd.Function):
    @staticmethod
    def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_qcomm=False):
        """Forward function."""
        ctx.group = group
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes
        ctx.use_qcomm = use_qcomm

        world_size = torch.distributed.get_world_size(group=group)
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input
        input = input.contiguous()
        if output_split_sizes is None:
            # Equal split (all2all)
            if use_qcomm:
                output = input.new_empty(
                    size=[input.shape[0], input.shape[1] + 4],
                    dtype=torch.int8,
                    device=torch.cuda.current_device(),
                )
            else:
                output = torch.empty_like(input)
        else:
            # Unequal split (all2all-v)
            if use_qcomm:
                output = input.new_empty(
                    size=[sum(output_split_sizes)] + list(input.size()[1:]),
                    dtype=torch.int8,
                    device=torch.cuda.current_device(),
                )
            else:
                output = input.new_empty(
                    size=[sum(output_split_sizes)] + list(input.size()[1:]),
                    dtype=input.dtype,
                    device=torch.cuda.current_device(),
                )
        if use_qcomm:
            output = q_alltoall(output, input, output_split_sizes, input_split_sizes, group)
        else:
            torch.distributed.all_to_all_single(
                output,
                input,
                output_split_sizes=output_split_sizes,
                input_split_sizes=input_split_sizes,
                group=group,
            )
        return output

    @staticmethod
    def backward(ctx, *grad_output):
        """Backward function."""
        return (
            None,
            _AllToAll.apply(
                ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.use_qcomm
            ),
            None,
            None,
            None,
        )


def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None, use_qcomm=False):
    """Wrapper for autograd function"""
    return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_qcomm)
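A minimal usage sketch of the new wrapper, assuming torch.distributed is already initialized (e.g. via torchrun) and CUDA is available; the group handle and tensor sizes are illustrative:

import torch
import torch.distributed as dist
# from dcu_megatron.core.tensor_parallel import all_to_all

# With equal splits every rank exchanges shards of the same size.
# tokens = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)
# exchanged   = all_to_all(dist.group.WORLD, tokens)                    # plain all-to-all
# exchanged_q = all_to_all(dist.group.WORLD, tokens, use_qcomm=True)    # int8-quantized payload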
dcu_megatron/core/tensor_parallel/qcomm.py (new file, mode 100644)

import torch
import triton
import triton.language as tl
import random
import unittest
import json
import os
import time


@triton.jit
def _fwd_kernel_destindex_copy_quantize_kv_init_asym(
    K, Out, Out_scale_zero,
    stride_k_bs, stride_k_h, stride_k_d,
    stride_o_bs, stride_o_h, stride_o_d,
    stride_os_bs, stride_os_h, stride_os_d,
    head_num, head_dim,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_HEAD: tl.constexpr
):
    cur_index = tl.program_id(0)
    offs_h = tl.arange(0, BLOCK_HEAD)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    dest_index = cur_index
    m1 = offs_h[:, None] < head_num
    m2 = offs_d[None, :] < head_dim
    mask = m1 & m2

    src_data = tl.load(
        K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],
        mask=mask,
        other=0.0,
    ).to(tl.float32)
    src_data_max = tl.max(src_data, axis=1, keep_dims=True)
    src_data_min = tl.min(src_data, axis=1, keep_dims=True)
    data_scale = (src_data_max - src_data_min) / 255.0
    data_zero = (-1 * src_data_min / data_scale).to(tl.int32)
    q_src_data = (
        tl.clamp(
            (src_data / data_scale).to(tl.int32).to(tl.float32) + data_zero.to(tl.float32),
            0.0, 255.0
        ).to(tl.int32) - 128
    ).to(tl.int8)
    data_scale = data_scale.to(Out_scale_zero.dtype.element_ty)
    data_zero = data_zero.to(Out_scale_zero.dtype.element_ty)

    o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
    os_ptrs = Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]
    oz_ptrs = Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] + 1
    tl.store(o_ptrs, q_src_data, mask=mask)
    tl.store(os_ptrs, data_scale, mask=m1)
    tl.store(oz_ptrs, data_zero, mask=m1)


@torch.no_grad()
def destindex_copy_quantize_kv_init_asym(K, Out, Out_scale_zero):
    bs_seq = K.shape[0]
    head_num = K.shape[1]
    head_dim = K.shape[2]
    assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]
    BLOCK_HEAD = triton.next_power_of_2(head_num)
    BLOCK_DMODEL = triton.next_power_of_2(head_dim)
    grid = (bs_seq,)
    num_warps = 1

    _fwd_kernel_destindex_copy_quantize_kv_init_asym[grid](
        K, Out, Out_scale_zero,
        K.stride(0), K.stride(1), K.stride(2),
        Out.stride(0), Out.stride(1), Out.stride(2),
        Out_scale_zero.stride(0), Out_scale_zero.stride(1), Out_scale_zero.stride(2),
        head_num, head_dim,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_HEAD=BLOCK_HEAD,
        num_warps=num_warps,
        num_stages=1,
    )
    return


@triton.jit
def _bwd_kernel_destindex_dequantize_kv(
    Quantized_Out, Out_scale_zero, Dequantized_Out,
    stride_qo_bs, stride_qo_h, stride_qo_d,
    stride_os_bs, stride_os_h, stride_os_d,
    stride_do_bs, stride_do_h, stride_do_d,
    head_num, head_dim,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_HEAD: tl.constexpr
):
    cur_index = tl.program_id(0)
    offs_h = tl.arange(0, BLOCK_HEAD)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    scales_dtype = Out_scale_zero.dtype.element_ty

    dest_index = cur_index
    m1 = offs_h[:, None] < head_num
    m2 = offs_d[None, :] < head_dim
    mask = m1 & m2

    # Load quantized data
    q_data = tl.load(
        Quantized_Out + dest_index * stride_qo_bs + offs_h[:, None] * stride_qo_h + stride_qo_d * offs_d[None, :],
        mask=mask,
        other=0,
    )

    # Load scale and zero point
    data_scale = tl.load(
        Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None],
        mask=m1,
        other=1.0,
    )
    data_zero = tl.load(
        Out_scale_zero + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] + 1,
        mask=m1,
        other=0,
    )

    # Dequantize
    dequantized_data = (q_data.to(tl.int32) + 128 - data_zero.to(tl.int32)).to(scales_dtype) * data_scale

    # Store dequantized data
    out_ptrs = Dequantized_Out + dest_index * stride_do_bs + stride_do_h * offs_h[:, None] + stride_do_d * offs_d[None, :]
    tl.store(out_ptrs, dequantized_data, mask=mask)


@torch.no_grad()
def destindex_dequantize_kv(Quantized_Out, Out_scale_zero, Dequantized_Out):
    bs_seq = Quantized_Out.shape[0]
    head_num = Quantized_Out.shape[1]
    head_dim = Quantized_Out.shape[2]
    assert (
        Quantized_Out.shape[1] == Dequantized_Out.shape[1]
        and Quantized_Out.shape[2] == Dequantized_Out.shape[2]
    )
    BLOCK_HEAD = triton.next_power_of_2(head_num)
    BLOCK_DMODEL = triton.next_power_of_2(head_dim)
    grid = (bs_seq,)
    num_warps = 1

    _bwd_kernel_destindex_dequantize_kv[grid](
        Quantized_Out, Out_scale_zero, Dequantized_Out,
        Quantized_Out.stride(0), Quantized_Out.stride(1), Quantized_Out.stride(2),
        Out_scale_zero.stride(0), Out_scale_zero.stride(1), Out_scale_zero.stride(2),
        Dequantized_Out.stride(0), Dequantized_Out.stride(1), Dequantized_Out.stride(2),
        head_num, head_dim,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_HEAD=BLOCK_HEAD,
        num_warps=num_warps,
        num_stages=1,
    )


@torch.no_grad()
def fp16_to_int8s(fp16_tensor):
    fp16_bytes = fp16_tensor.contiguous().view(torch.int8)
    int8_high = fp16_bytes[::2]   # high 8 bits
    int8_low = fp16_bytes[1::2]   # low 8 bits
    return int8_high.unsqueeze(1), int8_low.unsqueeze(1)


@torch.no_grad()
def int8s_to_fp16(int8_high, int8_low):
    fp16_bytes = torch.stack([int8_high, int8_low], dim=-1).view(torch.int16)
    return fp16_bytes.view(torch.bfloat16)


def _alltoall(group, input, output_split_sizes, input_split_sizes):
    input = input.contiguous()
    if output_split_sizes is None:
        # Equal split (all2all)
        output = torch.empty_like(input)
    else:
        # Unequal split (all2all-v)
        output = input.new_empty(
            size=[sum(output_split_sizes)] + list(input.size()[1:]),
            dtype=input.dtype,
            device=torch.cuda.current_device(),
        )
    torch.distributed.all_to_all_single(
        output,
        input,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )
    return output


def q_alltoall(output, input, output_split_sizes, input_split_sizes, group):
    t, s = input.shape[0], input.shape[1]
    input_buffer_int8 = torch.empty((t, 1, s), dtype=torch.int8, device="cuda")
    buffer_scales = torch.empty((t, 1, 2), dtype=torch.bfloat16, device="cuda")
    input_q = input.unsqueeze(1)
    destindex_copy_quantize_kv_init_asym(
        input_q,
        input_buffer_int8,
        buffer_scales,
    )
    input_buffer_int8 = input_buffer_int8.squeeze()
    buffer_scales = buffer_scales.squeeze()
    buffer_scales_h, buffer_scales_l = fp16_to_int8s(buffer_scales[:, 0])
    buffer_shift_h, buffer_shift_l = fp16_to_int8s(buffer_scales[:, 1])
    input_all = torch.cat(
        [input_buffer_int8, buffer_scales_h, buffer_scales_l, buffer_shift_h, buffer_shift_l], dim=1
    )
    torch.distributed.all_to_all_single(
        output,
        input_all,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )
    scale = int8s_to_fp16(output[:, -4], output[:, -3])
    shift = int8s_to_fp16(output[:, -2], output[:, -1])
    scales = torch.cat([scale, shift], dim=1).unsqueeze(1)
    deq_out = torch.empty(
        (output.shape[0], 1, output.shape[1] - 4), dtype=torch.bfloat16, device="cuda"
    )
    destindex_dequantize_kv(output[:, :-4].unsqueeze(1), scales, deq_out)
    return deq_out.squeeze()
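The Triton kernels implement a per-token asymmetric uint8 scheme: scale = (max − min) / 255, zero = −min / scale, q = clamp(trunc(x / scale) + zero, 0, 255) − 128 stored as int8, with the bfloat16 scale and zero point split into byte pairs and carried in the last four int8 columns of the all-to-all payload. As a rough pure-PyTorch reference of the same round trip for checking the kernels on small inputs (the helper names below are illustrative, not from the diff):

import torch

def quantize_rowwise_asym(x: torch.Tensor):
    # x: [tokens, hidden]; quantize each row independently.
    x32 = x.float()
    row_max = x32.amax(dim=1, keepdim=True)
    row_min = x32.amin(dim=1, keepdim=True)
    scale = (row_max - row_min) / 255.0
    zero = (-row_min / scale).to(torch.int32)
    q = torch.clamp((x32 / scale).to(torch.int32).float() + zero.float(), 0.0, 255.0)
    q = (q.to(torch.int32) - 128).to(torch.int8)
    return q, scale.to(torch.bfloat16), zero.to(torch.bfloat16)

def dequantize_rowwise_asym(q, scale, zero):
    return (q.to(torch.int32) + 128 - zero.to(torch.int32)).to(torch.bfloat16) * scale

x = torch.randn(4, 64, dtype=torch.bfloat16)
q, s, z = quantize_rowwise_asym(x)
x_hat = dequantize_rowwise_asym(q, s, z)
print((x.float() - x_hat.float()).abs().max())  # error stays roughly within one scale step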
dcu_megatron/core/transformer/moe/token_dispatcher.py

@@ -3,8 +3,8 @@ from typing import Optional, Tuple

import torch

from megatron.training import get_args
from megatron.core.tensor_parallel import (
    all_to_all,
    gather_from_sequence_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
)

@@ -15,6 +15,8 @@ from megatron.core.transformer.moe.moe_utils import (

)
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher
from dcu_megatron.core.tensor_parallel import all_to_all


# decouple perbatch state from MoEAlltoAllTokenDispatcher
class MoEAlltoAllPerBatchState:

@@ -35,6 +37,13 @@ class MoEAlltoAllPerBatchState:

class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # use_qcomm
        args = get_args()
        self.use_qcomm = args.use_qcomm

    def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
        state.num_global_tokens_per_local_expert = getattr(
            self, "num_global_tokens_per_local_expert", None

@@ -125,7 +134,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):

            "before_ep_alltoall", tokens_per_expert
        )
        global_input_tokens = all_to_all(
            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits,
            use_qcomm=self.use_qcomm
        )
        return tokens_per_expert, global_input_tokens

@@ -249,7 +258,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):

        # Perform expert parallel AlltoAll communication
        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
        permutated_local_input_tokens = all_to_all(
            self.ep_group, hidden_states, self.input_splits, self.output_splits,
            use_qcomm=self.use_qcomm
        )
        return permutated_local_input_tokens
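Note the split ordering: dispatch passes (output_splits, input_splits) while the return path passes (input_splits, output_splits), i.e. send and receive counts are mirrored between the two directions. A tiny standalone illustration of unequal-split all-to-all semantics with torch.distributed (run under torchrun with 2 ranks; the split sizes are illustrative):

import torch
import torch.distributed as dist

# Run with: torchrun --nproc_per_node=2 this_script.py
dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Rank 0 sends 1 row to itself and 3 rows to rank 1; rank 1 sends 2 rows each way.
input_splits = [1, 3] if rank == 0 else [2, 2]
# Each rank must also know how many rows it receives from every peer.
output_splits = [1, 2] if rank == 0 else [3, 2]

x = torch.full((sum(input_splits), 4), float(rank), device="cuda")
out = x.new_empty((sum(output_splits), 4))
dist.all_to_all_single(out, x, output_split_sizes=output_splits,
                       input_split_sizes=input_splits, group=dist.group.WORLD)
print(rank, out.shape)  # rank 0 receives 3 rows, rank 1 receives 5 rows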
dcu_megatron/core/transformer/transformer_layer.py

@@ -10,41 +10,13 @@ from megatron.core.utils import (

    deprecate_inference_params,
    make_viewless_tensor,
)
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables


class TransformerLayer(MegatronCoreTransformerLayer):

    def _callable_wrapper(self, is_forward, func, stream, event, *args, skip_detach=False, **kwargs):
        """
        Wraps a function call so that it waits for a given CUDA event before
        proceeding and then runs the function on a specified CUDA stream.
        """
        torch.cuda.nvtx.range_push(func.__name__)
        event.wait(stream)
        with torch.cuda.stream(stream):
            outputs = func(*args, **kwargs)
        event.record(stream)

        if skip_detach:
            torch.cuda.nvtx.range_pop()
            return outputs

        detached_output_tensors = []
        if not is_forward:
            torch.cuda.nvtx.range_pop()
            return outputs, detached_output_tensors

        for tensor in outputs:
            if tensor is None:
                detached_output_tensors.append(None)
            elif tensor.dtype.is_floating_point:
                detached_output_tensors.append(tensor.detach().requires_grad_(True))
            else:
                detached_output_tensors.append(tensor.detach())
        torch.cuda.nvtx.range_pop()
        return outputs, detached_output_tensors

    def forward(
        self,
        hidden_states: Tensor,

@@ -61,6 +33,23 @@ class TransformerLayer(MegatronCoreTransformerLayer):

        *,
        inference_params: Optional[Any] = None,
    ):
        if not isinstance(self.mlp, MoELayer):
            return super().forward(
                hidden_states=hidden_states,
                context=context,
                context_mask=context_mask,
                attention_mask=attention_mask,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                attention_bias=attention_bias,
                inference_context=inference_context,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
                inference_params=inference_params,
            )

        (
            hidden_states,
            pre_mlp_layernorm_output,

@@ -123,7 +112,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):

        residual = hidden_states

        # Optional Input Layer norm
        input_layernorm_output = self.input_layernorm(hidden_states)
        if self.recompute_input_layernorm:
            self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
            input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(self.input_layernorm, hidden_states)
        else:
            input_layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output_with_bias = self.self_attention(

@@ -138,6 +133,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):

            sequence_len_offset=sequence_len_offset,
        )

        if self.recompute_input_layernorm:
            # discard the output of the input layernorm and register the recompute
            # as a gradient hook of attention_output_with_bias[0]
            self.input_layernorm_checkpoint.discard_output_and_register_recompute(
                attention_output_with_bias[0]
            )

        # TODO: could we move `bias_dropout_add_exec_handler` itself
        # inside the module provided in the `bias_dropout_add_spec` module?
        with self.bias_dropout_add_exec_handler():

@@ -178,7 +180,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):

        )

        # Optional Layer norm post the cross-attention.
        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
        if self.recompute_pre_mlp_layernorm:
            self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
            pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(self.pre_mlp_layernorm, hidden_states)
        else:
            pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

        probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
        tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(

@@ -249,6 +257,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):

        if shared_expert_output is not None:
            output += shared_expert_output
        mlp_output_with_bias = (output, mlp_bias)

        if self.recompute_pre_mlp_layernorm:
            # discard the output of the pre-mlp layernorm and register the recompute
            # as a gradient hook of mlp_output_with_bias[0]
            self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
                mlp_output_with_bias[0]
            )

        # TODO: could we move `bias_dropout_add_exec_handler` itself
        # inside the module provided in the `bias_dropout_add_spec` module?
        with self.bias_dropout_add_exec_handler():
            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                mlp_output_with_bias, residual, self.hidden_dropout

@@ -259,10 +277,11 @@ class TransformerLayer(MegatronCoreTransformerLayer):

        return output

    def _submodule_attention_dw(self):
        self.self_attention.backward_dw()
        # raise NotImplementedError("Not implemented")

    def _submodule_attention_router_compound_dw(self):
        self._submodule_attention_dw()

    def _submodule_mlp_dw(self):
        self.mlp.backward_dw()
        # raise NotImplementedError("Not implemented")
dcu_megatron/training/arguments.py

@@ -23,6 +23,7 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):

    # add extra arguments
    parser = _add_extra_network_size_args(parser)
    parser = _add_extra_training_args(parser)
    parser = _add_extra_initialization_args(parser)
    parser = _add_extra_distributed_args(parser)
    parser = _add_extra_tokenizer_args(parser)
    parser = _add_extra_moe_args(parser)

@@ -96,6 +97,14 @@ def _add_extra_training_args(parser):

    return parser


def _add_extra_initialization_args(parser):
    group = parser.add_argument_group(title='extra initialization args')

    group.add_argument('--reproduce', action='store_true',
                       help='reproduce train loss, need set --seed > 0.')
    return parser


def _add_extra_tokenizer_args(parser):
    # remove the original parameters
    remove_original_params(parser, ["tokenizer_type"])

@@ -120,6 +129,10 @@ def _add_extra_tokenizer_args(parser):

                              'NullTokenizer',
                              'DeepSeekV2Tokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--use-qcomm', default=False, action="store_true",
                       help='use quantized communication')

    return parser
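Both new flags are plain store_true switches: --use-qcomm enables the int8 all-to-all in the MoE token dispatcher, and --reproduce makes _set_random_seed enforce deterministic execution. A minimal standalone argparse sketch of the same pattern, outside the Megatron parser plumbing:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='extra args')
group.add_argument('--use-qcomm', default=False, action='store_true',
                   help='use quantized (int8) all-to-all communication')
group.add_argument('--reproduce', action='store_true',
                   help='enable deterministic training; requires --seed > 0 and zero dropout')
args = parser.parse_args(['--use-qcomm'])
print(args.use_qcomm, args.reproduce)  # True False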
dcu_megatron/training/initialize.py

"""Megatron initialization."""
import random
import time

import numpy as np
import torch
from datetime import timedelta

from megatron.training import get_args
from megatron.core import mpu, tensor_parallel


def _compile_dependencies():

@@ -105,7 +108,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):

    # Call the init process
    init_process_group_kwargs = {
        'backend': args.distributed_backend,
        'world_size': args.world_size,
        'rank': args.rank,
        'init_method': args.dist_url,

@@ -149,3 +152,35 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):

        f"> initialized pipeline model parallel with size "
        f"{mpu.get_pipeline_model_parallel_world_size()}"
    )


def _set_random_seed(
    seed_: int,
    data_parallel_random_init: bool = False,
    te_rng_tracker: bool = False,
    inference_rng_tracker: bool = False,
    use_cudagraphable_rng: bool = False,
):
    """Set random seed for reproducibility."""
    args = get_args()
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            tensor_parallel.model_parallel_cuda_manual_seed(
                seed, te_rng_tracker, inference_rng_tracker, use_cudagraphable_rng
            )
        if args.reproduce:
            assert (args.attention_dropout > 0) is False, \
                f"To utilize the reproduction function, args.attention_dropout = {args.attention_dropout} must be set to 0."
            assert (args.hidden_dropout > 0) is False, \
                f"To utilize the reproduction function, args.hidden_dropout = {args.hidden_dropout} must be set to 0."
            torch.backends.cudnn.deterministic = True   # force deterministic cuDNN algorithms
            torch.backends.cudnn.benchmark = False      # pin the convolution algorithm selection
            torch.use_deterministic_algorithms(True)    # use deterministic torch operators to avoid nondeterminism
    else:
        raise ValueError("Seed ({}) should be a positive integer.".format(seed_))
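The per-rank offsets keep RNG streams distinct across parallel groups (seed + 100·pp_rank, plus 10·dp_rank when data-parallel random init is enabled), while the --reproduce branch pins the deterministic backends. The same determinism knobs in a standalone PyTorch snippet, with no Megatron state involved (the helper name and unconditional dp offset are for illustration only):

import os
import random
import numpy as np
import torch

def set_deterministic(seed: int, pp_rank: int = 0, dp_rank: int = 0):
    # Mirror the offset scheme: different pipeline/data-parallel ranks, different streams.
    seed = seed + 100 * pp_rank + 10 * dp_rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Determinism knobs; cuBLAS also needs a workspace config for deterministic GEMMs.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

set_deterministic(1234, pp_rank=1, dp_rank=0)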
dcu_megatron/training/tokenizer/tokenizer.py

@@ -9,8 +9,10 @@ from megatron.training.tokenizer.tokenizer import (

    _Llama2Tokenizer,
    CustomTikTokenizer,
    _NullTokenizer,
    _NullMultimodalTokenizer,
    _vocab_size_with_padding
)
from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer


def build_tokenizer(args, **kwargs):

@@ -92,7 +94,11 @@ def build_tokenizer(args, **kwargs):

            args.tokenizer_prompt_format,
            args.special_tokens,
            args.image_tag_type,
            args.force_system_message,
        )
    elif args.tokenizer_type == 'NullMultimodalTokenizer':
        assert args.vocab_size is not None
        tokenizer = _NullMultimodalTokenizer(args.vocab_size)
    elif args.tokenizer_type == "DeepSeekV2Tokenizer":
        tokenizer = _DeepSeekV2Tokenizer(args.tokenizer_model, args.extra_vocab_size)
        args.padded_vocab_size = tokenizer.vocab_size
dcu_megatron/training/training.py

@@ -53,18 +53,9 @@ from megatron.training.training import (

stimer = StragglerDetector()


def train(forward_step_func, model, optimizer, opt_param_scheduler,
          train_data_iterator, valid_data_iterator,
          process_non_loss_data_func, config, checkpointing_context,
          non_loss_data_func):
    """Training function: run train_step desired number of times, run validation, checkpoint."""
    args = get_args()
    timers = get_timers()

@@ -74,10 +65,7 @@ def train(

    try:
        from workload_inspector.utils.webserver import run_server
        import threading
        threading.Thread(target=run_server, daemon=True,
                         args=(torch.distributed.get_rank(),)).start()
    except ModuleNotFoundError:
        print_rank_0("workload inspector module not found.")

@@ -100,17 +88,11 @@ def train(

    rerun_state_machine.current_iteration = iteration

    # Track E2E metrics at the start of training.
    one_logger_utils.on_train_start(
        iteration=iteration, consumed_train_samples=args.consumed_train_samples,
        train_samples=args.train_samples, seq_length=args.seq_length,
        train_iters=args.train_iters, save=args.save, async_save=args.async_save,
        log_throughput=args.log_throughput,
        num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)

    num_floating_point_operations_so_far = args.num_floating_point_operations_so_far

@@ -118,10 +100,9 @@ def train(

    config.grad_scale_func = optimizer.scale_loss
    config.timers = timers
    if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
        assert config.no_sync_func is None, \
            'When overlap_grad_reduce is True, config.no_sync_func must be None; ' \
            'a custom no_sync_func is not supported when overlapping grad-reduce'
        config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
        if len(model) == 1:
            config.no_sync_func = config.no_sync_func[0]

@@ -145,9 +126,8 @@ def train(

    if args.manual_gc:
        # Disable the default garbage collector and perform the collection manually.
        # This is to align the timing of garbage collection across ranks.
        assert args.manual_gc_interval >= 0, \
            'Manual garbage collection interval should be larger than or equal to 0'
        gc.disable()
        gc.collect()

@@ -157,13 +137,10 @@ def train(

        world = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        mmcnt = args.straggler_minmax_count
        stimer.configure(world, rank,
                         mmcnt=mmcnt,
                         enabled=not args.disable_straggler_on_startup,
                         port=args.straggler_ctrlr_port)
    num_floating_point_operations_since_last_log_event = 0.0

    num_microbatches = get_num_microbatches()

@@ -171,10 +148,10 @@ def train(

    eval_iterations = 0

    def get_e2e_base_metrics():
        """Get base metrics values for one-logger to calculate E2E tracking metrics."""
        num_floating_point_operations_since_current_train_start = \
            num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
        return {
            'iteration': iteration,
            'train_duration': timers('interval-time').active_time(),

@@ -184,7 +161,7 @@ def train(

            'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
            'consumed_train_samples': args.consumed_train_samples,
            'world_size': args.world_size,
            'seq_length': args.seq_length
        }

    # Cache into one-logger for callback.
    if one_logger:

@@ -192,11 +169,7 @@ def train(

        one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)

    prof = None
    if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
        def trace_handler(p):
            from pathlib import Path
            Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)

@@ -242,9 +215,8 @@ def train(

        pre_hook_enabled = False
        # Also, check weight hash across DP replicas to be very pedantic.
        if args.check_weight_hash_across_dp_replicas_interval is not None:
            assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
                "Parameter hashes not matching across DP replicas"
            torch.distributed.barrier()
            print_rank_0(f">>> Weight hashes match after {iteration} iterations...")

@@ -270,20 +242,14 @@ def train(

            # to make sure training configuration is still valid.
            update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
            if get_num_microbatches() != num_microbatches and iteration != 0:
                assert get_num_microbatches() > num_microbatches, \
                    f"Number of microbatches should be increasing due to batch size rampup; " \
                    f"instead going from {num_microbatches} to {get_num_microbatches()}"
                if args.save is not None:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             opt_param_scheduler,
                                             num_floating_point_operations_so_far,
                                             checkpointing_context,
                                             train_data_iterator=train_data_iterator)
            num_microbatches = get_num_microbatches()
            update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)

@@ -292,9 +258,9 @@ def train(

            # Dummy train_step to fast forward train_data_iterator.
            dummy_train_step(train_data_iterator)
            iteration += 1
            batch_size = mpu.get_data_parallel_world_size() * \
                         args.micro_batch_size * \
                         get_num_microbatches()
            args.consumed_train_samples += batch_size
            args.skipped_train_samples += batch_size
            continue

@@ -302,28 +268,19 @@ def train(

        # Run training step.
        args.curr_iteration = iteration
        ft_integration.on_training_step_start()
        loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       opt_param_scheduler,
                       config)
        ft_integration.on_training_step_end()
        if should_checkpoint:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     opt_param_scheduler,
                                     num_floating_point_operations_so_far,
                                     checkpointing_context,
                                     train_data_iterator=train_data_iterator)
        if should_exit:
            break

@@ -346,13 +303,12 @@ def train(

            pre_hook_enabled = True

        iteration += 1
        batch_size = mpu.get_data_parallel_world_size() * \
                     args.micro_batch_size * \
                     get_num_microbatches()
        args.consumed_train_samples += batch_size
        num_skipped_samples_in_batch = (get_current_global_batch_size() -
                                        get_current_running_global_batch_size())
        if args.decrease_batch_size_if_needed:
            assert num_skipped_samples_in_batch >= 0
        else:

@@ -378,22 +334,16 @@ def train(

                decoupled_learning_rate = param_group['lr']
            else:
                learning_rate = param_group['lr']
        report_memory_flag = training_log(loss_dict, total_loss_dict, learning_rate,
                                          decoupled_learning_rate,
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)

        # Evaluation.
        if args.eval_interval and iteration % args.eval_interval == 0 and \
                args.do_valid:
            timers('interval-time').stop()
            if should_disable_forward_pre_hook(args):
                disable_forward_pre_hook(model)

@@ -403,18 +353,11 @@ def train(

                gc.collect()
            prefix = f'iteration {iteration}'
            timers('eval-time', log_level=0).start(barrier=True)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, process_non_loss_data_func,
                                       config, verbose=False, write_to_tensorboard=True,
                                       non_loss_data_func=non_loss_data_func)
            eval_duration += timers('eval-time').elapsed()
            eval_iterations += args.eval_iters
            timers('eval-time').stop()

@@ -430,25 +373,13 @@ def train(

        # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
        # Some of these only happen at specific iterations.
        post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
                                     num_floating_point_operations_since_last_log_event)

        # Checkpoint and decide whether to exit.
        should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler,
                                                 iteration, num_floating_point_operations_so_far,
                                                 checkpointing_context, train_data_iterator)
        if should_exit:
            break

@@ -477,7 +408,6 @@ def train(

        if wandb_writer:
            wandb_writer.finish()
        ft_integration.shutdown()
        one_logger_utils.finish()
        sys.exit(exit_code)

    return iteration, num_floating_point_operations_so_far
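The merge description mentions adding a trace_handler to train(); the visible hunk creates the profile directory before traces are written. For reference, a typical torch.profiler handler of that shape looks like the sketch below (the export filename, schedule, and dummy workload are illustrative, not taken from this diff):

import torch
from pathlib import Path

profile_dir = "./profile_traces"   # illustrative; the real run uses args.profile_dir

def trace_handler(p):
    # Create the output directory on first use, then dump a Chrome trace per cycle.
    Path(profile_dir).mkdir(parents=True, exist_ok=True)
    p.export_chrome_trace(f"{profile_dir}/trace_step{p.step_num}.json")

with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
    on_trace_ready=trace_handler,
) as prof:
    for _ in range(8):
        torch.randn(128, 128) @ torch.randn(128, 128)  # stand-in for a training step
        prof.step()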