evt_fugx1 / dcu_megatron · Commit eb4333f0

Authored May 14, 2025 by wangxj
Add a reproduce feature: when enabled, the training loss is exactly reproducible, at a small cost to training performance.
parent 57944e55
Showing 5 changed files with 64 additions and 2 deletions (+64 −2)
dcu_megatron/adaptor/megatron_adaptor.py    +4   −0
dcu_megatron/training/arguments.py          +9   −0
dcu_megatron/training/initialize.py         +36  −1
examples/llama/hostfile                     +1   −0
examples/llama/train_llama2_7b_1nodes.sh    +14  −1
dcu_megatron/adaptor/megatron_adaptor.py

...
@@ -191,6 +191,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..training.initialize import _initialize_distributed
         from ..training.initialize import _compile_dependencies
         from ..training.training import train
+        from ..training.initialize import _set_random_seed

         MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
                                     build_tokenizer)
...
@@ -200,6 +201,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         # remove fused_kernels
         MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
                                     _compile_dependencies)
+        # add fixed seed
+        MegatronAdaptation.register('megatron.training.initialize._set_random_seed',
+                                    _set_random_seed)
         # add trace_handler
         MegatronAdaptation.register('megatron.training.training.train',
...
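Note: this diff shows only the MegatronAdaptation.register(...) calls, not the class itself. As a rough sketch of how such a dotted-path patch registry typically works (the register/apply names below are illustrative assumptions, not the repo's actual implementation):

import importlib

_PATCHES = {}  # hypothetical registry mirroring the register(...) calls above

def register(dotted_path, replacement):
    # Remember which attribute should later be replaced by which object.
    _PATCHES[dotted_path] = replacement

def apply_patches():
    # Resolve 'pkg.module.attr' and overwrite the attribute in its module,
    # so code that looks the attribute up afterwards gets the replacement.
    for dotted_path, replacement in _PATCHES.items():
        module_path, attr = dotted_path.rsplit('.', 1)
        module = importlib.import_module(module_path)
        setattr(module, attr, replacement)

Registering 'megatron.training.initialize._set_random_seed' this way swaps Megatron's seeding routine for the patched version added in dcu_megatron/training/initialize.py below.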
dcu_megatron/training/arguments.py

...
@@ -24,6 +24,7 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
     # add extra arguments
     parser = _add_extra_network_size_args(parser)
     parser = _add_extra_training_args(parser)
+    parser = _add_extra_initialization_args(parser)
     parser = _add_extra_distributed_args(parser)
     parser = _add_extra_tokenizer_args(parser)
...
@@ -101,6 +102,14 @@ def _add_extra_training_args(parser):
     return parser


+def _add_extra_initialization_args(parser):
+    group = parser.add_argument_group(title='extra initialization args')
+
+    group.add_argument('--reproduce', action='store_true',
+                       help='reproduce train loss; requires --seed > 0.')
+
+    return parser
+
+
 def _add_extra_tokenizer_args(parser):
     # remove original params
     remove_original_params(parser, ["tokenizer_type"])
...
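The hunk above calls remove_original_params(parser, ["tokenizer_type"]), but the helper itself is not part of this diff. One plausible implementation, relying on private argparse attributes (an assumption for illustration, not the repo's code):

import argparse

def remove_original_params(parser: argparse.ArgumentParser, dests):
    # Drop existing actions by dest so the adaptor can re-register the same
    # arguments with different choices or defaults.
    for action in list(parser._actions):
        if action.dest in dests:
            parser._remove_action(action)
            for opt in action.option_strings:
                parser._option_string_actions.pop(opt, None)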
dcu_megatron/training/initialize.py

 """Megatron initialization."""
 import time

 import torch
+import random
+import numpy as np
 from datetime import timedelta

 from megatron.training import get_args
-from megatron.core import mpu
+from megatron.core import mpu, tensor_parallel


 def _compile_dependencies():
...
@@ -149,3 +151,36 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
         f"> initialized pipeline model parallel with size "
         f"{mpu.get_pipeline_model_parallel_world_size()}"
     )
+
+
+def _set_random_seed(
+    seed_: int,
+    data_parallel_random_init: bool = False,
+    te_rng_tracker: bool = False,
+    inference_rng_tracker: bool = False,
+    use_cudagraphable_rng: bool = False,
+):
+    """Set random seed for reproducibility."""
+    args = get_args()
+    if seed_ is not None and seed_ > 0:
+        # Ensure that different pipeline MP stages get different seeds.
+        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
+        # Ensure different data parallel ranks get different seeds.
+        if data_parallel_random_init:
+            seed = seed + (10 * mpu.get_data_parallel_rank())
+        # Set CPU random seeds.
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        if torch.cuda.device_count() > 0:
+            # Set GPU random seeds.
+            tensor_parallel.model_parallel_cuda_manual_seed(
+                seed, te_rng_tracker, inference_rng_tracker, use_cudagraphable_rng
+            )
+        if args.reproduce:
+            assert args.attention_dropout == 0, (
+                f"To use the reproduce feature, args.attention_dropout = "
+                f"{args.attention_dropout} must be set to 0."
+            )
+            assert args.hidden_dropout == 0, (
+                f"To use the reproduce feature, args.hidden_dropout = "
+                f"{args.hidden_dropout} must be set to 0."
+            )
+            torch.backends.cudnn.deterministic = True  # force deterministic cuDNN/MIOpen algorithms
+            torch.backends.cudnn.benchmark = False     # disable auto-tuning so the conv algorithm is fixed
+            torch.use_deterministic_algorithms(True)   # use deterministic torch ops to avoid nondeterminism
+    else:
+        raise ValueError("Seed ({}) should be a positive integer.".format(seed_))
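The patched _set_random_seed above is the whole reproducibility recipe: seed every RNG stream (Python, NumPy, torch CPU/GPU), then force deterministic kernel selection. The asserts require dropout to be off because dropout injects extra per-step randomness that would also have to be replayed exactly. A self-contained sketch of the same recipe on a toy model, runnable without Megatron (all names are illustrative):

import random

import numpy as np
import torch

def set_reproducible(seed: int):
    random.seed(seed)                         # Python RNG
    np.random.seed(seed)                      # NumPy RNG
    torch.manual_seed(seed)                   # torch CPU and GPU RNGs
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False    # no conv algorithm auto-tuning
    torch.use_deterministic_algorithms(True)  # error on nondeterministic ops

def train_once(seed: int, steps: int = 5) -> float:
    set_reproducible(seed)
    model = torch.nn.Linear(8, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(32, 8), torch.randn(32, 1)
    loss = None
    for _ in range(steps):
        loss = torch.nn.functional.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item()

# Two runs with the same seed yield bitwise-identical losses.
assert train_once(1234) == train_once(1234)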
examples/llama/hostfile  (new file, mode 100644)

+node036 slots=8
\ No newline at end of file
examples/llama/train_llama2_7b_1nodes.sh

 #!/bin/bash

+INITIALIZATION_ARGS=(--num-workers 2)
+
 for para in $*
 do
     if [[ $para == --data_path* ]]; then
...
@@ -10,6 +12,16 @@ do
         checkpoint_path=${para#*=}
     elif [[ $para == --profiling* ]]; then
         profiling=${para#*=}
+    elif [[ $para == --reproduce* ]]; then
+        INITIALIZATION_ARGS=(--reproduce --num-workers 0)
+        export MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC=1  # enable MIOpen deterministic algorithms
+        export ROCBLAS_ATOMICS_MODE=0                    # disable rocBLAS atomic operations
+        # Disable MIOpen conv algorithms that use atomics; keep only GEMM.
+        export MIOPEN_DEBUG_CONV_FFT=0
+        export MIOPEN_DEBUG_CONV_DIRECT=0
+        export MIOPEN_DEBUG_CONV_GEMM=1
+        export MIOPEN_DEBUG_CONV_WINOGRAD=0
+        export MIOPEN_DEBUG_CONV_IMPLICIT_GEMM=0
     fi
 done
...
@@ -63,7 +75,7 @@ TRAINING_ARGS=(
     --use-legacy-models
     --micro-batch-size 1
     --global-batch-size 256
-    --train-iters 10
+    --train-iters 50
     --weight-decay 0.1
     --adam-beta1 0.9
     --adam-beta2 0.95
...
@@ -134,6 +146,7 @@ APP="python -u ${MEGATRON_PATH}/pretrain_gpt.py \
     ${DATA_ARGS[@]} \
     ${EVAL_AND_LOGGING_ARGS[@]} \
     ${DISTRIBUTED_ARGS[@]} \
+    ${INITIALIZATION_ARGS[@]} \
     "

 if [[ $profiling == "torch" ]]; then
...
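Enabling the feature from the shell is just a matter of passing --reproduce to the launch script; the parsing loop above then swaps INITIALIZATION_ARGS and exports the MIOpen/rocBLAS switches before Python starts. The same environment could equally be set from Python, provided it happens before torch touches ROCm; a hedged sketch mirroring the script's values (whether each debug switch is honored depends on the installed ROCm stack):

import os

DETERMINISTIC_ROCM_ENV = {
    "MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": "1",  # deterministic conv algorithms
    "ROCBLAS_ATOMICS_MODE": "0",                    # no atomic ops in rocBLAS
    "MIOPEN_DEBUG_CONV_FFT": "0",
    "MIOPEN_DEBUG_CONV_DIRECT": "0",
    "MIOPEN_DEBUG_CONV_GEMM": "1",                  # keep only the GEMM path
    "MIOPEN_DEBUG_CONV_WINOGRAD": "0",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM": "0",
}
os.environ.update(DETERMINISTIC_ROCM_ENV)  # must run before any ROCm work starts

Note also that --reproduce replaces --num-workers 2 with --num-workers 0: single-process data loading removes one more source of nondeterministic ordering.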