Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
6e0faa70
Unverified
Commit
6e0faa70
authored
Jan 31, 2023
by
HELSON
Committed by
GitHub
Jan 31, 2023
Browse files
[gemini] add profiler in the demo (#2534)
parent
df437ca0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
3 deletions
+44
-3
examples/language/gpt/gemini/commons/utils.py
examples/language/gpt/gemini/commons/utils.py
+29
-0
examples/language/gpt/gemini/train_gpt_demo.py
examples/language/gpt/gemini/train_gpt_demo.py
+15
-3
No files found.
examples/language/gpt/gemini/commons/utils.py
View file @
6e0faa70
import
time
from
contextlib
import
nullcontext
import
torch
from
torch.profiler
import
ProfilerActivity
,
profile
,
schedule
,
tensorboard_trace_handler
class DummyProfiler:
    """No-op stand-in for ``torch.profiler.profile``.

    Exposes the same ``step()`` interface so training loops can call
    ``prof.step()`` unconditionally, whether or not profiling is enabled.
    """

    def __init__(self):
        # Mirror the real profiler's step counter so the duck-typed
        # interface matches.
        self.step_number = 0

    def step(self):
        """Advance the internal step counter (otherwise a no-op)."""
        self.step_number = self.step_number + 1
# Randomly Generated Data
...
...
@@ -10,3 +23,19 @@ def get_data(batch_size, seq_len, vocab_size):
def get_tflops(model_numel, batch_size, seq_len, step_time):
    """Estimate achieved TFLOPS for one training step.

    The FLOP count is approximated as ``8 * model_numel * tokens`` where
    ``tokens = batch_size * seq_len`` (the factor 8 presumably covers the
    forward + backward passes of the transformer — confirm against the
    model config).

    Args:
        model_numel: number of model parameters.
        batch_size: samples per step.
        seq_len: tokens per sample.
        step_time: wall-clock seconds for the step.

    Returns:
        Throughput in TFLOPS (float).
    """
    # Keep the original evaluation order so results are bit-identical.
    total_flops = model_numel * batch_size * seq_len * 8
    # The 1e-12 epsilon guards against division by zero for a zero-length step.
    return total_flops / 1e12 / (step_time + 1e-12)
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
    """Build a context manager for (optionally) profiling a training loop.

    Args:
        enable_flag: when falsy, profiling is skipped entirely.
        warmup_steps: steps the profiler warms up before recording.
        active_steps: steps actively recorded after warmup.
        save_dir: directory for the TensorBoard trace output.

    Returns:
        A ``torch.profiler.profile`` context when enabled, otherwise a
        ``nullcontext`` yielding a ``DummyProfiler`` so callers can still
        invoke ``prof.step()``.
    """
    if not enable_flag:
        # Profiling disabled: hand back a dummy whose step() is a no-op.
        return nullcontext(DummyProfiler())

    prof_schedule = schedule(wait=0, warmup=warmup_steps, active=active_steps)
    return profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=prof_schedule,
        on_trace_ready=tensorboard_trace_handler(save_dir),
        record_shapes=True,
        profile_memory=True,
    )
def get_time_stamp():
    """Return the current local time as a ``"DD-HH:MM"`` string.

    Used to build unique-ish profiler trace directory names.
    """
    return time.strftime("%d-%H:%M", time.localtime())
examples/language/gpt/gemini/train_gpt_demo.py
View file @
6e0faa70
...
...
@@ -6,7 +6,7 @@ import psutil
import
torch
import
torch.nn
as
nn
from
commons.model_zoo
import
model_builder
from
commons.utils
import
get_data
,
get_
tflops
from
commons.utils
import
get_data
,
get_
profile_context
,
get_tflops
,
get_time_stamp
from
packaging
import
version
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
...
...
@@ -201,7 +201,8 @@ def main():
WARMUP_STEPS
=
1
assert
WARMUP_STEPS
<
NUM_STEPS
,
"warmup steps should smaller than the total steps"
assert
(
NUM_STEPS
-
WARMUP_STEPS
)
%
2
==
1
,
"the number of valid steps should be odd to take the median "
assert
(
NUM_STEPS
-
WARMUP_STEPS
)
%
2
==
1
,
"the number of valid steps should be odd to take the median"
PROF_FLAG
=
False
# The flag of profiling, False by default
disable_existing_loggers
()
colossalai
.
launch_from_torch
(
config
=
{})
...
...
@@ -292,7 +293,8 @@ def main():
torch
.
cuda
.
synchronize
()
model
.
train
()
tflops_list
=
[]
for
n
in
range
(
NUM_STEPS
):
def
train_step
():
# we just use randomly generated data here
input_ids
,
attn_mask
=
get_data
(
BATCH_SIZE
,
SEQ_LEN
,
VOCAB_SIZE
)
optimizer
.
zero_grad
()
...
...
@@ -331,6 +333,16 @@ def main():
if
n
>=
WARMUP_STEPS
:
tflops_list
.
append
(
step_tflops
)
demo_profiler
=
get_profile_context
(
PROF_FLAG
,
WARMUP_STEPS
,
NUM_STEPS
-
WARMUP_STEPS
,
save_dir
=
f
"profile/
{
get_time_stamp
()
}
-demo"
)
with
demo_profiler
as
prof
:
for
n
in
range
(
NUM_STEPS
):
train_step
()
prof
.
step
()
tflops_list
.
sort
()
median_index
=
((
NUM_STEPS
-
WARMUP_STEPS
)
>>
1
)
+
WARMUP_STEPS
logger
.
info
(
f
"Median TFLOPS is
{
tflops_list
[
median_index
]:.
3
f
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment