evt_fugx1 / dcu_megatron · Commits · e103a256

Commit e103a256 · authored Jun 12, 2025 by dongcl
patch for megatron commit 0595ef2b0c93f8d61f473c9f99f9ff73803ff919
parent ade7b0dc
Showing 7 changed files with 390 additions and 223 deletions:

  dcu_megatron/core/extensions/transformer_engine.py                     +7    -1
  dcu_megatron/core/pipeline_parallel/combined_1f1b.py                  +10    -1
  dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py   +2    -4
  dcu_megatron/core/pipeline_parallel/schedules.py                      +58   -52
  dcu_megatron/legacy/model/transformer.py                               +6    -0
  dcu_megatron/training/training.py                                    +190   -93
  pretrain_gpt.py                                                      +117   -72
dcu_megatron/core/extensions/transformer_engine.py  (+7 -1)

@@ -10,6 +10,7 @@ from packaging.version import Version as PkgVersion
 from megatron.training import get_args
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.process_groups_config import ModelCommProcessGroups
 from megatron.core.tensor_parallel import get_cuda_rng_tracker
 from megatron.core.utils import get_te_version, is_te_min_version
 from megatron.core.extensions.transformer_engine import TEDotProductAttention
@@ -20,7 +21,6 @@ from megatron.core.extensions.transformer_engine import TELinear as MegatronCoreTELinear
 from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
 from megatron.core.parallel_state import (
-    get_context_parallel_global_ranks,
     get_context_parallel_group,
     get_hierarchical_context_parallel_groups,
     get_tensor_model_parallel_group,
@@ -69,6 +69,8 @@ class TELinear(MegatronCoreTELinear):
         skip_weight_param_allocation: bool,
         tp_comm_buffer_name: Optional[str] = None,
         is_expert: bool = False,
+        symmetric_ar_type: Optional[str] = None,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         args = get_args()
         self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
@@ -90,6 +92,8 @@ class TELinear(MegatronCoreTELinear):
             skip_weight_param_allocation=skip_weight_param_allocation,
             tp_comm_buffer_name=tp_comm_buffer_name,
             is_expert=is_expert,
+            symmetric_ar_type=symmetric_ar_type,
+            tp_group=tp_group,
         )

     def backward_dw(self):
@@ -118,6 +122,7 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
         is_expert: bool,
         skip_weight_param_allocation: bool = False,
         tp_comm_buffer_name: Optional[str] = None,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         args = get_args()
         self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
@@ -139,6 +144,7 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
             is_expert=is_expert,
             skip_weight_param_allocation=skip_weight_param_allocation,
             tp_comm_buffer_name=tp_comm_buffer_name,
+            tp_group=tp_group,
         )

     def backward_dw(self):
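The pattern in this file is a thin subclass that accepts the keyword arguments newer Megatron-core releases added to `TELinear` (`symmetric_ar_type`, `tp_group`) and forwards them unchanged to the upstream constructor, while reading the DCU-specific `split_bw` flag defensively. A minimal, self-contained sketch of that forwarding style (the `BaseLinear`/`PatchedLinear` names are illustrative, not the real classes):

```python
from typing import Optional


class BaseLinear:
    """Stand-in for the upstream megatron-core TELinear constructor."""

    def __init__(self, in_features: int, out_features: int,
                 tp_comm_buffer_name: Optional[str] = None,
                 symmetric_ar_type: Optional[str] = None,
                 tp_group=None):
        self.tp_comm_buffer_name = tp_comm_buffer_name
        self.symmetric_ar_type = symmetric_ar_type
        self.tp_group = tp_group


class PatchedLinear(BaseLinear):
    """Thin wrapper: read an optional site-specific flag, then forward
    every upstream keyword argument unchanged to the parent class."""

    def __init__(self, in_features: int, out_features: int, *, args=None, **kwargs):
        # getattr guard mirrors `args.split_bw if hasattr(args, "split_bw") else False`
        self.split_bw = getattr(args, "split_bw", False)
        super().__init__(in_features, out_features, **kwargs)


if __name__ == "__main__":
    layer = PatchedLinear(16, 32, tp_comm_buffer_name="qkv", symmetric_ar_type=None)
    print(layer.split_bw, layer.tp_comm_buffer_name)
```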
dcu_megatron/core/pipeline_parallel/combined_1f1b.py  (+10 -1)

@@ -282,6 +282,7 @@ def forward_backward_step(
     checkpoint_activations_microbatch=None,
     is_first_microbatch=False,
     current_microbatch=None,
+    vp_stage=None,
     encoder_decoder_xattn=False,
 ):
     """Forward step for passed-in model.
@@ -345,6 +346,8 @@ def forward_backward_step(
             Whether it is the first microbatch. Defaults to False.
         current_microbatch (int, optional):
             The current microbatch. Defaults to None.
+        vp_stage (int, optional):
+            The virtual pipeline stage. Defaults to None.

     Returns:
         Tensor or list[Tensor]: The output object(s) from the forward step.
@@ -435,13 +438,19 @@ def forward_backward_step(
     num_tokens = None
     if f_model:
         with f_context:
+            model_vp_stage = getattr(f_model, "vp_stage", None)
+            if vp_stage is not None and model_vp_stage is not None:
+                assert (
+                    vp_stage == model_vp_stage
+                ), f"vp_stage ({vp_stage}) doesn't match model_vp_stage ({model_vp_stage})"
             num_tokens = torch.tensor(0, dtype=torch.int)
             args = get_args()
             is_last_stage = False
             if args.schedule_method == "dualpipev":
                 is_last_stage = parallel_state.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
             else:
-                is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=False)
+                is_last_stage = parallel_state.is_pipeline_last_stage(
+                    ignore_virtual=False, vp_stage=vp_stage
+                )
             if is_last_stage:
                 if not collect_non_loss_data:
                     loss_node = ScheduleNode(
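The new `vp_stage` argument is validated against the stage recorded on the model chunk itself before the forward step runs. A small sketch of that defensive check, using a toy object in place of the real Megatron module (names are illustrative):

```python
class DummyChunk:
    """Toy stand-in for a model chunk that knows which virtual stage it holds."""

    def __init__(self, vp_stage=None):
        self.vp_stage = vp_stage


def check_vp_stage(model, vp_stage):
    # Mirror the diff: only compare when both sides actually carry a stage id.
    model_vp_stage = getattr(model, "vp_stage", None)
    if vp_stage is not None and model_vp_stage is not None:
        assert vp_stage == model_vp_stage, (
            f"vp_stage ({vp_stage}) doesn't match model_vp_stage ({model_vp_stage})"
        )


check_vp_stage(DummyChunk(vp_stage=2), vp_stage=2)     # passes
check_vp_stage(DummyChunk(vp_stage=None), vp_stage=2)  # passes: nothing to compare
```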
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py  (+2 -4)

@@ -1111,11 +1111,9 @@ def forward_backward_pipelining_with_cutinhalf(
             output_tensor_grad, _ = recv_backward(tensor_shape, config, master_chunk_id)
             output_tensor_grads[master_chunk_id].append(output_tensor_grad)

-            input_tensor_grad = backward_step_helper(
-                master_chunk_id,
+            _, input_tensor_grad = forward_backward_helper_wrapper(
+                bwd_model_chunk_id=master_chunk_id, bwd_cur_microbatch=cur_bwd_chunk_microbatch[master_chunk_id]
             )
-            cur_bwd_chunk_microbatch[master_chunk_id] += 1

             _ = send_backward(
                 input_tensor_grad,
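The call site now names its arguments (`bwd_model_chunk_id=...`, `bwd_cur_microbatch=...`) and no longer increments the per-chunk microbatch counter itself, which suggests the wrapper owns that bookkeeping. A hypothetical sketch of a helper that advances the counter internally (this is not the real `forward_backward_helper_wrapper`, only the call-site shape):

```python
from collections import defaultdict

_bwd_microbatch = defaultdict(int)  # per-chunk backward microbatch counter


def backward_helper(*, bwd_model_chunk_id, bwd_cur_microbatch):
    """Keyword-only helper: the caller states which chunk and microbatch it
    means, and the helper advances the counter so call sites stay simple."""
    assert bwd_cur_microbatch == _bwd_microbatch[bwd_model_chunk_id]
    grad = f"grad(chunk={bwd_model_chunk_id}, mb={bwd_cur_microbatch})"
    _bwd_microbatch[bwd_model_chunk_id] += 1
    return None, grad


_, g0 = backward_helper(bwd_model_chunk_id=0, bwd_cur_microbatch=_bwd_microbatch[0])
_, g1 = backward_helper(bwd_model_chunk_id=0, bwd_cur_microbatch=_bwd_microbatch[0])
print(g0, g1)
```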
dcu_megatron/core/pipeline_parallel/schedules.py  (+58 -52)

 import contextlib
-from typing import Iterator, List, Union
+from functools import partial
+from typing import Callable, Iterator, List, Optional, Union

 import torch

@@ -12,6 +13,8 @@ from megatron.core.utils import (
     get_model_config,
     get_model_type,
     get_model_xattn,
+    nvtx_range_pop,
+    nvtx_range_push,
 )
 from megatron.core.pipeline_parallel.schedules import (
     forward_step,

@@ -82,10 +85,11 @@ def forward_backward_pipelining_with_interleaving(
     num_microbatches: int,
     seq_length: int,
     micro_batch_size: int,
-    decoder_seq_length: int = None,
+    decoder_seq_length: Optional[int] = None,
     forward_only: bool = False,
     collect_non_loss_data: bool = False,
-    first_val_step: bool = None,
+    first_val_step: Optional[bool] = None,
+    adjust_tensor_shapes_fn: Optional[Callable] = None,  # unused
 ):
     """Run interleaved 1F1B schedule (model split into model chunks), with
     communication between pipeline stages as needed.

@@ -106,6 +110,9 @@ def forward_backward_pipelining_with_interleaving(
     assert isinstance(
         data_iterator, list
     ), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
+    assert (
+        adjust_tensor_shapes_fn is None
+    ), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism"

     config = get_model_config(model[0])

@@ -373,11 +380,8 @@ def forward_backward_pipelining_with_interleaving(
     def forward_step_helper(
         virtual_microbatch_id, microbatch_id, checkpoint_activations_microbatch
     ):
-        """Helper method to run forward step with model split into chunks
-        (run set_virtual_pipeline_model_parallel_rank() before calling
-        forward_step())."""
+        """Helper method to run forward step with model split into chunks"""
         model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=True)
-        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

         # launch param synchronization for next model chunk
         # Note: Asynchronous communication tends to slow down compute.

@@ -399,7 +403,7 @@ def forward_backward_pipelining_with_interleaving(
         )

         # forward step
-        if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+        if parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=model_chunk_id):
             if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
                 input_tensors[model_chunk_id].append(None)

@@ -427,6 +431,7 @@ def forward_backward_pipelining_with_interleaving(
                 is_first_microbatch_for_model_chunk(virtual_microbatch_id),
             ),
             current_microbatch=microbatch_id,
+            vp_stage=model_chunk_id,
         )
         output_tensors[model_chunk_id].append(output_tensor)

@@ -443,13 +448,8 @@ def forward_backward_pipelining_with_interleaving(
         return output_tensor

     def backward_step_helper(virtual_microbatch_id):
-        """Helper method to run backward step with model split into chunks
-        (run set_virtual_pipeline_model_parallel_rank() before calling
-        backward_step())."""
+        """Helper method to run backward step with model split into chunks"""
+        nonlocal output_tensor_grads
         model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)
-        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

         # launch grad synchronization (default)
         if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(

@@ -459,7 +459,7 @@ def forward_backward_pipelining_with_interleaving(
             synchronized_model_chunks.add(model_chunk_id)

         # pylint: disable=E0606
-        if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+        if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=model_chunk_id):
             if len(output_tensor_grads[model_chunk_id]) == 0:
                 output_tensor_grads[model_chunk_id].append(None)
         input_tensor = input_tensors[model_chunk_id].pop(0)

@@ -509,7 +509,7 @@ def forward_backward_pipelining_with_interleaving(
         if f_virtual_microbatch_id is not None:
             model_chunk_id = get_model_chunk_id(f_virtual_microbatch_id, forward=True)
             f_model_chunk_id = model_chunk_id
-            f_context = VppContextManager(f_model_chunk_id)
+            # f_context = VppContextManager(f_model_chunk_id)

             with f_context:
                 # launch param synchronization for next model chunk
                 # Note: Asynchronous communication tends to slow down compute.

@@ -533,7 +533,7 @@ def forward_backward_pipelining_with_interleaving(
                 )

                 # forward step
-                if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+                if parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=model_chunk_id):
                     if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
                         input_tensors[model_chunk_id].append(None)

@@ -556,7 +556,7 @@ def forward_backward_pipelining_with_interleaving(
         if b_virtual_microbatch_id is not None:
             model_chunk_id = get_model_chunk_id(b_virtual_microbatch_id, forward=False)
             b_model_chunk_id = model_chunk_id
-            b_context = VppContextManager(b_model_chunk_id)
+            # b_context = VppContextManager(b_model_chunk_id)

             with b_context:
                 # launch grad synchronization (default)
                 if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(

@@ -565,7 +565,7 @@ def forward_backward_pipelining_with_interleaving(
                     enable_grad_sync()
                     synchronized_model_chunks.add(model_chunk_id)

-                if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+                if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=model_chunk_id):
                     if len(output_tensor_grads[model_chunk_id]) == 0:
                         output_tensor_grads[model_chunk_id].append(None)
                 b_input_tensor = input_tensors[model_chunk_id].pop(0)

@@ -602,6 +602,7 @@ def forward_backward_pipelining_with_interleaving(
                     ),
                 ),
                 current_microbatch=f_microbatch_id,
+                vp_stage=f_model_chunk_id,
             )

         # forward post process

@@ -675,8 +676,6 @@ def forward_backward_pipelining_with_interleaving(
         input_tensor_grad = None
         if f_virtual_microbatch_id is not None:
             # forward pass
-            forward_model_chunk_id = get_model_chunk_id(f_virtual_microbatch_id, forward=True)
-            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
             if pre_forward is not None:
                 pre_forward()

@@ -689,8 +688,6 @@ def forward_backward_pipelining_with_interleaving(
         if b_virtual_microbatch_id is not None:
             # Backward pass.
-            backward_model_chunk_id = get_model_chunk_id(b_virtual_microbatch_id, forward=False)
-            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
             if pre_backward is not None:
                 pre_backward()
             input_tensor_grad = backward_step_helper(b_virtual_microbatch_id)

@@ -698,9 +695,15 @@ def forward_backward_pipelining_with_interleaving(
                 input_tensor_grad = post_backward(input_tensor_grad)

         return output_tensor, input_tensor_grad

+    is_vp_first_stage = partial(parallel_state.is_pipeline_first_stage, ignore_virtual=False)
+    is_vp_last_stage = partial(parallel_state.is_pipeline_last_stage, ignore_virtual=False)
+
     # Run warmup forward passes.
+    nvtx_range_push(suffix="warmup")
     parallel_state.set_virtual_pipeline_model_parallel_rank(0)
-    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
+    input_tensors[0].append(
+        p2p_communication.recv_forward(tensor_shape, config, is_vp_first_stage(vp_stage=0))
+    )

     fwd_wait_handles = None
     fwd_wait_recv_handles = None
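A recurring edit in this file replaces bare `parallel_state.is_pipeline_first_stage(ignore_virtual=False)` / `is_pipeline_last_stage(...)` calls with `is_vp_first_stage(vp_stage=...)` / `is_vp_last_stage(vp_stage=...)`, two helpers built once with `functools.partial` so that `ignore_virtual=False` is pinned and only the chunk id varies per call. A minimal sketch of the same construction with stand-in stage predicates (the real functions also consult the pipeline rank):

```python
from functools import partial


def is_pipeline_first_stage(ignore_virtual=True, vp_stage=None):
    """Stand-in for parallel_state.is_pipeline_first_stage."""
    return (vp_stage is None or vp_stage == 0) if not ignore_virtual else True


def is_pipeline_last_stage(ignore_virtual=True, vp_stage=None, num_vp_stages=4):
    """Stand-in for parallel_state.is_pipeline_last_stage."""
    return (vp_stage is None or vp_stage == num_vp_stages - 1) if not ignore_virtual else True


# Pin ignore_virtual=False once; every later call only supplies the chunk id.
is_vp_first_stage = partial(is_pipeline_first_stage, ignore_virtual=False)
is_vp_last_stage = partial(is_pipeline_last_stage, ignore_virtual=False)

print(is_vp_first_stage(vp_stage=0))  # True
print(is_vp_last_stage(vp_stage=1))   # False for a 4-chunk schedule
```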
@@ -727,10 +730,9 @@ def forward_backward_pipelining_with_interleaving(
     for k in range(num_warmup_microbatches):
         cur_model_chunk_id = get_model_chunk_id(k, forward=True)
-        parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
         if config.overlap_p2p_comm_warmup_flush:
-            if not parallel_state.is_pipeline_first_stage(ignore_virtual=False) and k != 0:
+            if not is_vp_first_stage(vp_stage=cur_model_chunk_id) and k != 0:
                 assert recv_prev_wait_handles, (
                     f'pp rank {pipeline_parallel_rank}, iteration {k},'
                     'should have registered recv handle'

@@ -777,7 +779,7 @@ def forward_backward_pipelining_with_interleaving(
         )

         # Don't send tensor downstream if on last stage.
-        if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+        if is_vp_last_stage(vp_stage=cur_model_chunk_id):
             output_tensor = None

         # Send and receive tensors as appropriate (send tensors computed

@@ -880,8 +882,10 @@ def forward_backward_pipelining_with_interleaving(
             if recv_next:
                 output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])
+    nvtx_range_pop(suffix="warmup")

     # Run 1F1B in steady state.
+    nvtx_range_push(suffix="steady")
     for k in range(num_microbatches_remaining):
         # Forward pass.
         forward_k = k + num_warmup_microbatches

@@ -895,14 +899,15 @@ def forward_backward_pipelining_with_interleaving(
         else:
             checkpoint_activations_microbatch = None

+        microbatch_id = get_microbatch_id_in_model_chunk(forward_k, forward=True)
         if config.overlap_p2p_comm:
+            # output send / receive sync
             def pp_pre_forward():
                 nonlocal recv_prev_wait_handles

                 cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
-                parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
-                if not parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+                if not is_vp_first_stage(vp_stage=cur_model_chunk_id):
                     if config.overlap_p2p_comm_warmup_flush:
                         assert recv_prev_wait_handles, (
                             f'pp rank {pipeline_parallel_rank}, fwd iteration {forward_k}, '

@@ -917,7 +922,6 @@ def forward_backward_pipelining_with_interleaving(
                 deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

-            # output async send / receive
             def pp_post_forward(output_tensor):
                 nonlocal send_next_wait_handle
                 nonlocal fwd_recv_buffer

@@ -927,10 +931,9 @@ def forward_backward_pipelining_with_interleaving(
                 # Determine if current stage has anything to send in either direction,
                 # otherwise set tensor to None.
                 forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
-                parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)

                 # Last virtual stage no activation tensor to send.
-                if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+                if is_vp_last_stage(vp_stage=forward_model_chunk_id):
                     output_tensor = None

                 recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(

@@ -963,8 +966,6 @@ def forward_backward_pipelining_with_interleaving(
                         recv_prev_wait_handles.append(fwd_wait_handles.pop("recv_prev"))
                 # assert fwd_wait_handles is not None

-                # Put input_tensor and output_tensor_grad in data structures in the
-                # right location.
                 if recv_prev:
                     input_tensors[next_forward_model_chunk_id].append(
                         fwd_recv_buffer[forward_k % fwd_recv_buffer_size]

@@ -973,14 +974,13 @@ def forward_backward_pipelining_with_interleaving(
                 return output_tensor

             # Backward pass.
             backward_k = k

             # grad send receive sync
             def pp_pre_backward():
                 nonlocal recv_next_wait_handles

                 backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
-                parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
-                if not parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+                if not is_vp_last_stage(vp_stage=backward_model_chunk_id):
                     if config.overlap_p2p_comm_warmup_flush:
                         assert recv_next_wait_handles, (
                             f'pp rank {pipeline_parallel_rank}, bwd iteration {backward_k}, '

@@ -1000,11 +1000,9 @@ def forward_backward_pipelining_with_interleaving(
                 nonlocal recv_next_wait_handles
                 nonlocal bwd_recv_buffer

-                backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
-                parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
                 # First virtual stage no activation gradient tensor to send.
-                if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+                backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+                if is_vp_first_stage(vp_stage=backward_model_chunk_id):
                     input_tensor_grad = None

                 recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage(

@@ -1036,6 +1034,7 @@ def forward_backward_pipelining_with_interleaving(
                         bwd_recv_buffer[backward_k % bwd_recv_buffer_size]
                     )
                     bwd_recv_buffer[(backward_k + 1) % bwd_recv_buffer_size] = None

                 return input_tensor_grad

             output_tensor, input_tensor_grad = forward_backward_helper_wrapper(

@@ -1061,13 +1060,11 @@ def forward_backward_pipelining_with_interleaving(
             # Determine if current stage has anything to send in either direction,
             # otherwise set tensor to None.
             forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
-            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
-            if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+            if is_vp_last_stage(vp_stage=forward_model_chunk_id):
                 output_tensor = None

             backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
-            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
-            if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+            if is_vp_first_stage(vp_stage=backward_model_chunk_id):
                 input_tensor_grad = None

             recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(

@@ -1104,8 +1101,11 @@ def forward_backward_pipelining_with_interleaving(
             output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

         deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+    nvtx_range_pop(suffix="steady")

-    # Run cooldown backward passes (flush out pipeline).
+    # Run cooldown backward passes (flush out pipeline) for the last model chunk.
+    nvtx_range_push(suffix="cooldown")
+    curr_vp_stage = config.virtual_pipeline_model_parallel_size - 1
     if not forward_only:
         if bwd_wait_handles is not None:
             for bwd_wait_handle in bwd_wait_handles.values():

@@ -1113,12 +1113,15 @@ def forward_backward_pipelining_with_interleaving(
         if are_all_microbatches_in_warmup:
             output_tensor_grads[num_model_chunks - 1].append(
-                p2p_communication.recv_backward(tensor_shape, config=config)
+                p2p_communication.recv_backward(
+                    tensor_shape, config=config, is_last_stage=is_vp_last_stage(vp_stage=curr_vp_stage)
+                )
             )
         for k in range(num_microbatches_remaining, total_num_microbatches):
             cur_model_chunk_id = get_model_chunk_id(k, forward=False)
-            parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
-            if not parallel_state.is_pipeline_last_stage(ignore_virtual=False) and k != 0:
+            if not is_vp_last_stage(vp_stage=cur_model_chunk_id) and k != 0:
                 if config.overlap_p2p_comm_warmup_flush:
                     assert recv_next_wait_handles, (
                         f'pp rank {pipeline_parallel_rank}, backward iteration {k}, '

@@ -1158,7 +1161,7 @@ def forward_backward_pipelining_with_interleaving(
             _, input_tensor_grad = forward_backward_helper_wrapper(b_virtual_microbatch_id=k)

             # First virtual stage no activation gradient tensor to send.
-            if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+            if is_vp_first_stage(vp_stage=cur_model_chunk_id):
                 input_tensor_grad = None

             if config.overlap_p2p_comm_warmup_flush:

@@ -1215,7 +1218,9 @@ def forward_backward_pipelining_with_interleaving(
             if model_chunk_id not in synchronized_model_chunks:
                 config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
                 synchronized_model_chunks.add(model_chunk_id)
+    nvtx_range_pop(suffix="cooldown")

+    nvtx_range_push(suffix="misc")
     assert (
         not recv_prev_wait_handles
     ), 'recv_prev_wait_handles should be cleared at the end of a step'

@@ -1245,5 +1250,6 @@ def forward_backward_pipelining_with_interleaving(
     if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
         create_cudagraphs()

+    nvtx_range_pop(suffix="misc")
     return forward_data_store
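The schedule is now bracketed into named profiling phases (`warmup`, `steady`, `cooldown`, `misc`) with `nvtx_range_push`/`nvtx_range_pop` from `megatron.core.utils`. A rough stand-in using PyTorch's built-in NVTX bindings, with the phase body left as a placeholder:

```python
import torch


def run_phase(name: str) -> None:
    """Bracket a schedule phase in a named NVTX range so it shows up as one
    block in an Nsight Systems timeline."""
    use_nvtx = torch.cuda.is_available()
    if use_nvtx:
        torch.cuda.nvtx.range_push(name)
    try:
        pass  # placeholder for the warmup / steady / cooldown loop body
    finally:
        if use_nvtx:
            torch.cuda.nvtx.range_pop()


for phase in ("warmup", "steady", "cooldown", "misc"):
    run_phase(phase)
```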
dcu_megatron/legacy/model/transformer.py  (+6 -0)

@@ -87,6 +87,12 @@ def parallel_attention_init_wrapper(fn):
     return wrapper


+class ParallelAttentionPatch(MegatronModule):
+    """Parallel self-attention layer abstract class.
+
+    Self-attention layer takes input with size [s, b, h]
+    and returns output of the same size.
+    """
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, inference_context=None,
                 rotary_pos_emb=None, *, inference_params=None):
dcu_megatron/training/training.py  (+190 -93)

This diff is collapsed and its contents are not shown.
pretrain_gpt.py  (+117 -72)

 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 """Pretrain GPT."""
+import datetime
 import os
 import torch
-from functools import partial
+from contextlib import nullcontext
+import inspect
+from functools import partial
 from typing import List, Optional, Tuple, Union

+from megatron.core import parallel_state
 from megatron.training import get_args
+from megatron.training import inprocess_restart
 from megatron.training import print_rank_0
 from megatron.training import get_timers
 from megatron.training import get_tokenizer
 from megatron.core import mpu
 from megatron.core.enums import ModelType
 from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
-from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
-from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
-from megatron.core.rerun_state_machine import get_rerun_state_machine
-import megatron.legacy.model
+from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
 from megatron.core.models.gpt import GPTModel
-from megatron.training import pretrain
-from megatron.core.utils import StragglerDetector
+from megatron.core.models.gpt.gpt_layer_specs import (
+    get_gpt_decoder_block_spec,
+    get_gpt_layer_local_spec,
+    get_gpt_layer_with_transformer_engine_spec,
+    get_gpt_mtp_block_spec,
+)
+from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
+    get_gpt_heterogeneous_layer_spec,
+)
+from megatron.core.rerun_state_machine import get_rerun_state_machine
 from megatron.core.transformer.spec_utils import import_module
+from megatron.core.utils import StragglerDetector
+from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
+from megatron.training.arguments import core_transformer_config_from_args
 from megatron.training.utils import (
     get_batch_on_this_cp_rank,
     get_batch_on_this_tp_rank,
     get_blend_and_blend_per_split,
 )
-from megatron.training.arguments import core_transformer_config_from_args
 from megatron.training.yaml_arguments import core_transformer_config_from_yaml
-from megatron.core.models.gpt.gpt_layer_specs import (
-    get_gpt_decoder_block_spec,
-    get_gpt_layer_local_spec,
-    get_gpt_layer_with_transformer_engine_spec,
-    get_gpt_mtp_block_spec,
-)

 from dcu_megatron import megatron_adaptor

+import megatron.legacy.model  # isort: skip
+# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import
+
+try:
+    from megatron.post_training.arguments import add_modelopt_args, modelopt_args_enabled
+    from megatron.post_training.loss_func import loss_func as loss_func_modelopt
+    from megatron.post_training.model_provider import model_provider as model_provider_modelopt
+
+    has_nvidia_modelopt = True
+except ImportError:
+    has_nvidia_modelopt = False

 stimer = StragglerDetector()

-def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
+def model_provider(
+    pre_process=True, post_process=True, vp_stage: Optional[int] = None
+) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
     """Builds the model.

     If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.

@@ -55,24 +75,33 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
         Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
     """
     args = get_args()

+    if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
+        return model_provider_modelopt(pre_process, post_process)
+
     if bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))):
         assert args.transformer_impl == "transformer_engine"

     use_te = args.transformer_impl == "transformer_engine"

     if args.record_memory_history:
         torch.cuda.memory._record_memory_history(True,
             # keep 100,000 alloc/free events from before the snapshot
             trace_alloc_max_entries=100000,
             # record stack information for the trace events
-            trace_alloc_record_context=True)
+            trace_alloc_record_context=True,
+        )

         def oom_observer(device, alloc, device_alloc, device_free):
             # snapshot right after an OOM happened
             print('saving allocated state during OOM')
             snapshot = torch.cuda.memory._snapshot()
             from pickle import dump
-            dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'))
+            dump(
+                snapshot,
+                open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'),
+            )

         torch._C._cuda_attach_out_of_memory_observer(oom_observer)

@@ -91,27 +120,42 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
             pre_process=pre_process,
             post_process=post_process,
         )
     else:  # using core models
         if args.spec is not None:
             transformer_layer_spec = import_module(args.spec)
         else:
             if args.num_experts:
                 # Define the decoder block spec
-                transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te, normalization=args.normalization)
+                transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te, normalization=args.normalization, qk_l2_norm=args.qk_l2_norm, vp_stage=vp_stage)
+            elif args.heterogeneous_layers_config_path is not None:
+                transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
             else:
                 # Define the decoder layer spec
                 if use_te:
                     transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-                        args.num_experts, args.moe_grouped_gemm,
-                        args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
+                        args.num_experts,
+                        args.moe_grouped_gemm,
+                        args.qk_layernorm,
+                        args.multi_latent_attention,
+                        args.moe_use_legacy_grouped_gemm,
+                        qk_l2_norm=args.qk_l2_norm)
                 else:
                     transformer_layer_spec = get_gpt_layer_local_spec(
-                        args.num_experts, args.moe_grouped_gemm,
-                        args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm,
-                        normalization=args.normalization)
+                        args.num_experts,
+                        args.moe_grouped_gemm,
+                        args.qk_layernorm,
+                        args.multi_latent_attention,
+                        args.moe_use_legacy_grouped_gemm,
+                        normalization=args.normalization,
+                    )

         mtp_block_spec = None
         if args.mtp_num_layers is not None:
-            mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)
+            mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te, vp_stage=vp_stage)

         model = GPTModel(
             config=config,

@@ -128,16 +172,19 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
             rotary_base=args.rotary_base,
             rope_scaling=args.use_rope_scaling,
             mtp_block_spec=mtp_block_spec,
+            vp_stage=vp_stage,
         )

+    print_rank_0(model)
     return model


 def get_batch(data_iterator):
     """Generate a batch."""
     # TODO: this is pretty hacky, find a better way
-    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
+    if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and (
+        not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
+    ):
         return None, None, None, None, None

     # get batches based on the TP rank you are on
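With interleaved (virtual) pipeline parallelism, the provider is called once per model chunk, and the new `vp_stage` parameter tells each chunk which virtual stage it implements so layer specs and MTP blocks can be built per stage. A hedged sketch of that calling convention with a toy provider (the builder loop below is illustrative, not Megatron's actual setup code):

```python
from typing import Optional


def model_provider(pre_process: bool = True, post_process: bool = True,
                   vp_stage: Optional[int] = None) -> dict:
    """Toy provider: records which virtual stage this chunk belongs to."""
    return {"pre_process": pre_process, "post_process": post_process, "vp_stage": vp_stage}


def build_model_chunks(virtual_pipeline_size: int):
    chunks = []
    for vp_stage in range(virtual_pipeline_size):
        chunks.append(
            model_provider(
                pre_process=(vp_stage == 0),                               # only the first chunk embeds
                post_process=(vp_stage == virtual_pipeline_size - 1),      # only the last computes loss
                vp_stage=vp_stage,
            )
        )
    return chunks


print(build_model_chunks(virtual_pipeline_size=2))
```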
@@ -153,12 +200,15 @@ def get_batch(data_iterator):
 SPIKY_LOSS_FACTOR = 10


-def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
+def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None):
     """Loss function.

     Args:
         loss_mask (torch.Tensor): Used to mask out some portions of the loss
         output_tensor (torch.Tensor): The tensor with the losses
+        model (GPTModel, optional): The model (can be wrapped)

     Returns:
         the loss scalar for this micro-batch

@@ -168,57 +218,48 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
     """
     args = get_args()

-    losses = output_tensor.float()
-    loss_mask = loss_mask.view(-1).float()
-    total_tokens = loss_mask.sum()
-    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
+    if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
+        return loss_func_modelopt(loss_mask, output_tensor, model=model)

-    if args.context_parallel_size > 1:
-        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
+    losses = output_tensor.view(-1).float()
+    loss_mask = loss_mask.view(-1).float()
+    loss = torch.sum(losses * loss_mask)

     # Check individual rank losses are not NaN prior to DP all-reduce.
     rerun_state_machine = get_rerun_state_machine()
     if args.check_for_nan_in_loss_and_grad:
         rerun_state_machine.validate_result(
-            result=loss[0],
+            result=loss,
             rejection_func=torch.isnan,
             message="found NaN in local forward loss calculation",
             tolerance=0.0,  # forward pass calculations are determinisic
             fatal=True,
         )
         rerun_state_machine.validate_result(
-            result=loss[0],
+            result=loss,
             rejection_func=torch.isinf,
             message="found Inf in local forward loss calculation",
             tolerance=0.0,  # forward pass calculations are determinisic
             fatal=True,
         )
     # Check for spiky loss
     if args.check_for_spiky_loss:
         rerun_state_machine.validate_result(
-            result=loss[0],
+            result=loss,
             rejection_func=partial(
                 rerun_state_machine.is_unexpectedly_large,
                 threshold=SPIKY_LOSS_FACTOR,
                 context="loss",
             ),
             message="Spiky loss",
             tolerance=0.0,  # forward pass calculations are determinisic
             fatal=False,
         )
     # Reduce loss for logging.
-    reporting_loss = loss.clone().detach()
-    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
-
-    # loss[0] is a view of loss, so it has ._base not None, which triggers assert error
-    # in core/pipeline_parallel/schedule.py::deallocate_output_tensor, calling .clone()
-    # on loss[0] fixes this
-    local_num_tokens = loss[1].clone().detach().to(torch.int)
-    return (
-        loss[0].clone(),
-        local_num_tokens,
-        {'lm loss': (reporting_loss[0], reporting_loss[1])},
-    )
+    num_tokens = loss_mask.sum().clone().detach().to(torch.int)
+    reporting_loss = torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])
+
+    return (loss, num_tokens, {'lm loss': reporting_loss})
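The rewritten `loss_func` returns the raw summed loss, the local token count, and a `[loss_sum, num_tokens]` pair for logging instead of a pre-reduced scalar, leaving the averaging to the caller. A small sketch of how a consumer would turn that triple into a per-token loss (single process, so the data-parallel all-reduce is omitted):

```python
import torch


def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    """Sketch of the new return convention: (loss_sum, num_tokens, logging dict)."""
    losses = output_tensor.view(-1).float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses * loss_mask)

    num_tokens = loss_mask.sum().clone().detach().to(torch.int)
    reporting_loss = torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])
    return loss, num_tokens, {"lm loss": reporting_loss}


per_token_losses = torch.rand(2, 8)           # pretend per-token LM losses
mask = torch.ones(2, 8); mask[:, -2:] = 0     # ignore the last two positions

loss, num_tokens, logs = loss_func(mask, per_token_losses)
# In training, reporting_loss would be summed across data-parallel ranks first.
print(float(logs["lm loss"][0] / logs["lm loss"][1]))  # average loss per unmasked token
```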
 def forward_step(data_iterator, model: GPTModel):

@@ -235,25 +276,26 @@ def forward_step(data_iterator, model: GPTModel):
     timers('batch-generator', log_level=2).start()
     global stimer
     with stimer(bdata=True):
         tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
             data_iterator)
     timers('batch-generator').stop()

     with stimer:
         if args.use_legacy_models:
             output_tensor = model(tokens, position_ids, attention_mask,
                                   labels=labels)
         else:
-            output_tensor = model(tokens, position_ids, attention_mask,
-                                  labels=labels, loss_mask=loss_mask)
+            output_tensor = model(
+                tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask)

-    return output_tensor, partial(loss_func, loss_mask)
+    # [ModelOpt]: model is needed to access ModelOpt distillation losses
+    return output_tensor, partial(loss_func, loss_mask, model=model)
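`forward_step` returns the output tensor together with a loss callback whose extra arguments are already bound via `functools.partial`; the pipeline schedule later calls that callback with only the output tensor. A toy version of the same closure pattern (the schedule call at the end is simulated):

```python
from functools import partial

import torch


def loss_func(loss_mask, output_tensor, model=None):
    """Toy loss: the model handle rides along via the keyword binding."""
    del model  # unused here; the real code can read distillation losses off it
    return torch.sum(output_tensor.view(-1).float() * loss_mask.view(-1).float())


def forward_step(batch, model):
    loss_mask = batch["loss_mask"]
    output_tensor = batch["tokens"].float() * 2   # pretend model forward
    # Bind everything except the output tensor, exactly what the schedule expects.
    return output_tensor, partial(loss_func, loss_mask, model=model)


batch = {"tokens": torch.arange(6).reshape(2, 3), "loss_mask": torch.ones(2, 3)}
output, loss_callback = forward_step(batch, model=None)
print(loss_callback(output))  # the schedule supplies only the output tensor
```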
 def is_dataset_built_on_rank():
     return (
-        mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
-    ) and mpu.get_tensor_model_parallel_rank() == 0
+        parallel_state.is_pipeline_first_stage(ignore_virtual=True)
+        or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
+    ) and parallel_state.get_tensor_model_parallel_rank() == 0


 def core_gpt_dataset_config_from_args(args):

@@ -278,7 +320,8 @@ def core_gpt_dataset_config_from_args(args):
         reset_attention_mask=args.reset_attention_mask,
         eod_mask_loss=args.eod_mask_loss,
         create_attention_mask=args.create_attention_mask_in_dataloader,
-        s3_cache_path=args.s3_cache_path,
+        object_storage_cache_path=args.object_storage_cache_path,
+        mid_level_dataset_surplus=args.mid_level_dataset_surplus,
     )

@@ -300,10 +343,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
     print_rank_0("> building train, validation, and test datasets for GPT ...")

     train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
         dataset_type,
         train_val_test_num_samples, is_dataset_built_on_rank, config
     ).build()

     print_rank_0("> finished creating GPT datasets ...")

@@ -316,10 +356,15 @@ if __name__ == "__main__":
     # Temporary for transition to core datasets
     train_valid_test_datasets_provider.is_distributed = True

+    # Optionally enable inprocess restart on pretrain
+    pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain)
+
     pretrain(
         train_valid_test_datasets_provider,
         model_provider,
         ModelType.encoder_or_decoder,
         forward_step,
         args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
+        extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
+        store=store,
     )