OpenDAS / Megatron-LM · Commits

Commit 1cc3fbe9, authored Mar 26, 2025 by silencealiang

    add hip profiler

Parent: 2757c9c9 · Pipeline #2579 passed

Showing 6 changed files with 59 additions and 2 deletions (+59 −2):
examples/gpt3/train_gpt_567B_1nodes.sh             +12 −0
examples/gpt3/train_gpt_567B_multinodes.sh         +12 −0
examples/mixtral/train_mixtral_8x7B_1nodes.sh      +12 −0
examples/mixtral/train_mixtral_8x7B_multinodes.sh  +12 −0
megatron/training/arguments.py                      +5 −2
megatron/training/training.py                       +6 −0
examples/gpt3/train_gpt_567B_1nodes.sh

@@ -113,6 +113,14 @@ TORCH_PROFIE_ARGS=(
     --use-pytorch-profiler
 )
 
+HIP_PROFIE_ARGS=(
+    --profile
+    --profile-ranks 0 1 2 3 4 5 6 7
+    --profile-step-start 4
+    --profile-step-end 5
+    --use-hip-profiler
+)
+
 MODEL_PARALLEL_ARGS=(
     --tensor-model-parallel-size 2
     --pipeline-model-parallel-size 1
@@ -154,6 +162,10 @@ APP="python3 -u ${MEGATRON_PATH}/pretrain_gpt.py \
 
 if [[ $profiling == "torch" ]]; then
     APP+=" ${TORCH_PROFIE_ARGS[@]}"
+elif [[ $profiling == "hip" ]]; then
+    mkdir -p hip_prof_data
+    APP+=" ${HIP_PROFIE_ARGS[@]}"
+    APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}"
 fi
 
 #for hygon cpu
examples/gpt3/train_gpt_567B_multinodes.sh

@@ -113,6 +113,14 @@ TORCH_PROFIE_ARGS=(
     --use-pytorch-profiler
 )
 
+HIP_PROFIE_ARGS=(
+    --profile
+    --profile-ranks 0 1 2 3 4 5 6 7
+    --profile-step-start 4
+    --profile-step-end 5
+    --use-hip-profiler
+)
+
 MODEL_PARALLEL_ARGS=(
     --tensor-model-parallel-size 2
     --pipeline-model-parallel-size 16
@@ -155,6 +163,10 @@ APP="python3 -u ${MEGATRON_PATH}/pretrain_gpt.py \
 
 if [[ $profiling == "torch" ]]; then
     APP+=" ${TORCH_PROFIE_ARGS[@]}"
+elif [[ $profiling == "hip" ]]; then
+    mkdir -p hip_prof_data
+    APP+=" ${HIP_PROFIE_ARGS[@]}"
+    APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}"
 fi
 
 #for hygon cpu
examples/mixtral/train_mixtral_8x7B_1nodes.sh

@@ -116,6 +116,14 @@ TORCH_PROFIE_ARGS=(
     --use-pytorch-profiler
 )
 
+HIP_PROFIE_ARGS=(
+    --profile
+    --profile-ranks 0 1 2 3 4 5 6 7
+    --profile-step-start 4
+    --profile-step-end 5
+    --use-hip-profiler
+)
+
 MODEL_PARALLEL_ARGS=(
     --tensor-model-parallel-size 2
     --pipeline-model-parallel-size 1
@@ -157,6 +165,10 @@ APP="python3 -u ${MEGATRON_PATH}/pretrain_gpt.py \
 
 if [[ $profiling == "torch" ]]; then
     APP+=" ${TORCH_PROFIE_ARGS[@]}"
+elif [[ $profiling == "hip" ]]; then
+    mkdir -p hip_prof_data
+    APP+=" ${HIP_PROFIE_ARGS[@]}"
+    APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}"
 fi
 
 #for hygon cpu
examples/mixtral/train_mixtral_8x7B_multinodes.sh

@@ -116,6 +116,14 @@ TORCH_PROFIE_ARGS=(
     --use-pytorch-profiler
 )
 
+HIP_PROFIE_ARGS=(
+    --profile
+    --profile-ranks 0 1 2 3 4 5 6 7
+    --profile-step-start 4
+    --profile-step-end 5
+    --use-hip-profiler
+)
+
 MODEL_PARALLEL_ARGS=(
     --tensor-model-parallel-size 2
     --pipeline-model-parallel-size 4
@@ -157,6 +165,10 @@ APP="python3 -u ${MEGATRON_PATH}/pretrain_gpt.py \
 
 if [[ $profiling == "torch" ]]; then
     APP+=" ${TORCH_PROFIE_ARGS[@]}"
+elif [[ $profiling == "hip" ]]; then
+    mkdir -p hip_prof_data
+    APP+=" ${HIP_PROFIE_ARGS[@]}"
+    APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}"
 fi
 
 #for hygon cpu
megatron/training/arguments.py

@@ -1408,10 +1408,13 @@ def _add_training_args(parser):
                        help='Use the built-in pytorch profiler. '
                        'Useful if you wish to view profiles in tensorboard.',
                        dest='use_pytorch_profiler')
-    group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
-                       help='Global ranks to profile.')
+    group.add_argument('--use-hip-profiler', action='store_true',
+                       help='Use HIP PROFILER',
+                       dest='use_hip_profiler')
     group.add_argument('--profile-dir', type=str, default="./",
                        help='profile dir to save.')
+    group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
+                       help='Global ranks to profile.')
     group.add_argument('--record-memory-history', action="store_true", default=False,
                        help='Record memory history in last rank.')
     group.add_argument('--memory-snapshot-path', type=str, default="snapshot.pickle",
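The net effect in arguments.py: a new --use-hip-profiler flag, with the existing --profile-ranks option moved (unchanged) below --profile-dir. A minimal standalone sketch of how these flags parse, fed the exact values HIP_PROFIE_ARGS passes in the example scripts; --profile-step-start/--profile-step-end are assumed to exist upstream, and their defaults here are placeholders, not values from this diff:

import argparse

# Sketch of the profiling options touched by this commit (subset of
# _add_training_args); option strings and defaults copied from the diff.
parser = argparse.ArgumentParser()
parser.add_argument('--profile', action='store_true')
parser.add_argument('--use-hip-profiler', action='store_true',
                    help='Use HIP PROFILER', dest='use_hip_profiler')
parser.add_argument('--profile-dir', type=str, default="./",
                    help='profile dir to save.')
parser.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
                    help='Global ranks to profile.')
# Assumed from upstream Megatron-LM; placeholder defaults.
parser.add_argument('--profile-step-start', type=int, default=10)
parser.add_argument('--profile-step-end', type=int, default=12)

# The flag set HIP_PROFIE_ARGS passes in the example scripts:
args = parser.parse_args(
    '--profile --profile-ranks 0 1 2 3 4 5 6 7 '
    '--profile-step-start 4 --profile-step-end 5 --use-hip-profiler'.split())
assert args.use_hip_profiler and args.profile_ranks == list(range(8))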
megatron/training/training.py

@@ -1519,6 +1519,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                 #on_trace_ready=torch.profiler.tensorboard_trace_handler('./torch_prof_data'))
                 on_trace_ready=trace_handler)
             prof.start()
+        elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
+            import ctypes
+            roctracer = ctypes.cdll.LoadLibrary("/opt/dtk/roctracer/lib/libroctracer64.so")
 
     start_iteration = iteration
     # Disable forward pre-hook to start training to ensure that errors in checkpoint loading
@@ -1543,6 +1546,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
         if args.profile and torch.distributed.get_rank() in args.profile_ranks:
             if args.use_pytorch_profiler:
                 prof.step()
+            elif args.use_hip_profiler:
+                if iteration == args.profile_step_start: roctracer.roctracer_start()
+                if iteration == args.profile_step_end: roctracer.roctracer_stop()
             elif iteration == args.profile_step_start:
                 torch.cuda.cudart().cudaProfilerStart()
                 torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
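How the two halves fit together: the shell scripts launch the job under hipprof with --trace-off, so tracing is disabled at startup; the training loop then opens and closes the trace window at the configured steps by calling roctracer's C API (roctracer_start/roctracer_stop) through ctypes. A hedged standalone sketch of the same pattern; the library path is specific to Hygon's DTK install per the diff, and maybe_toggle_trace is an illustrative helper, not a function from this commit:

import ctypes

# Sketch of the ctypes-driven trace window used in train() above.
# Library path is Hygon DTK specific; adjust for your ROCm/DTK install.
roctracer = ctypes.cdll.LoadLibrary("/opt/dtk/roctracer/lib/libroctracer64.so")

def maybe_toggle_trace(iteration, start_step=4, end_step=5):
    # Tracing is off at launch (hipprof --trace-off); these calls bound
    # the iterations whose HIP activity lands in hip_prof_data.
    if iteration == start_step:
        roctracer.roctracer_start()
    if iteration == end_step:
        roctracer.roctracer_stop()

# In the training loop, profiled ranks would call:
#     maybe_toggle_trace(iteration, args.profile_step_start, args.profile_step_end)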