OpenDAS / ColossalAI · Commits · 3629e611

Unverified commit 3629e611, authored Dec 29, 2022 by HELSON, committed by GitHub on Dec 29, 2022.

[example] update gpt benchmark (#2219)

Parent: 54de05da
Changes: 2 changed files, with 56 additions and 16 deletions (+56 -16).

examples/language/gpt/run.sh             +5  -5
examples/language/gpt/train_gpt_demo.py  +51 -11
examples/language/gpt/run.sh (view file @ 3629e611)
@@ -2,12 +2,12 @@
 export DISTPAN="colossalai"
 
 # The following options only valid when DISTPAN="colossalai"
-export TPDEGREE=4
-export GPUNUM=8
-export PLACEMENT='cpu'
+export TPDEGREE=1
+export GPUNUM=1
+export PLACEMENT='const'
 export USE_SHARD_INIT=False
 export BATCH_SIZE=32
-# export MODEL_TYPE="gpt2_24b"
+# export MODEL_TYPE="gpt2_10b"
 
 mkdir -p logs
-env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log
+torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log
examples/language/gpt/train_gpt_demo.py (view file @ 3629e611)
+import os
 from functools import partial
 from time import time
 
 import psutil
 import torch
 import torch.nn as nn
+from model_zoo import model_builder
 from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
...
@@ -15,7 +17,6 @@ from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, Proces
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
-from model_zoo import model_builder
 
 
 def parse_args():
...
@@ -88,7 +89,7 @@ class GPTLMLoss(nn.Module):
         return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
 
-## Randomly Generated Data
+# Randomly Generated Data
 def get_data(batch_size, seq_len, vocab_size):
     input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
     attention_mask = torch.ones_like(input_ids)
...
@@ -111,6 +112,22 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
     return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
 
 
+def get_model_size(model: nn.Module):
+    total_numel = 0
+    for module in model.modules():
+        for p in module.parameters(recurse=False):
+            total_numel += p.numel()
+    return total_numel
+
+
+def set_cpu_maximum_parallelism():
+    conf_str = torch.__config__.parallel_info()
+    inter_str = conf_str.split("hardware_concurrency() : ")[1]
+    max_concurrency = inter_str.split('\n')[0]
+    os.environ["OMP_NUM_THREADS"] = max_concurrency
+    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
+
+
 # Tensor Parallel
 def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """tensor_parallelize
...
@@ -157,10 +174,10 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
                           device=get_current_device(),
                           placement_policy=placememt_policy,
                           pin_memory=True,
-                          hidden_dim=4096,
+                          hidden_dim=8192,
                           search_range_mb=64)
         if placememt_policy == 'const':
-            model.gemini_manager._placement_policy.set_const_memory_boundary(10 * 1024)
+            model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
         chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
...
@@ -176,6 +193,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
 
 
 def main():
+    set_cpu_maximum_parallelism()
     args = parse_args()
 
     if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
...
@@ -187,6 +205,9 @@ def main():
     VOCAB_SIZE = 50257
     NUM_STEPS = 10
+    WARMUP_STEPS = 1
+    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
+    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "
 
     disable_existing_loggers()
     colossalai.launch_from_torch(config={})
...
@@ -239,7 +260,7 @@ def main():
                                   verbose=True)
 
     # model is shared after TP
-    numel = sum([p.numel() for p in model.parameters()])
+    numel = get_model_size(model)
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
 
     # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
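To make the Tflops_per_GPU comment concrete, a worked plug-in of the get_tflops formula with hypothetical numbers (none of these figures come from the commit):

# Hypothetical: 10e9 parameters, batch 32, sequence length 1024, a 10 s step.
model_numel, batch_size, seq_len, step_time = 10e9, 32, 1024, 10.0
tflops = model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
print(f"{tflops:.1f} TFLOPS per GPU")  # ~262.1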
...
@@ -249,29 +270,48 @@ def main():
     torch.cuda.synchronize()
     model.train()
+    tflops_list = []
 
     for n in range(NUM_STEPS):
         # we just use randomly generated data here
         input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
         optimizer.zero_grad()
+
         start = time()
         outputs = model(input_ids, attn_mask)
         loss = criterion(outputs, input_ids)
+        torch.cuda.synchronize()
+        fwd_end = time()
+        fwd_time = fwd_end - start
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
+
         if args.distplan in ["colossalai", "zero1", "zero2"]:
             optimizer.backward(loss)
         elif args.distplan in ["torch_ddp", "torch_zero"]:
             loss.backward()
+        torch.cuda.synchronize()
+        bwd_end = time()
+        bwd_time = bwd_end - fwd_end
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
+
         if args.distplan in ["zero1", "zero2"]:
             optimizer.sync_grad()
         optimizer.step()
-        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
         torch.cuda.synchronize()
+        optim_time = time() - bwd_end
         step_time = time() - start
-        logger.info(
-            f'[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
-            ranks=[0])
+        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
+
+        step_tflops = get_tflops_func(step_time)
+        logger.info(
+            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
+            ranks=[0],
+        )
+        if n >= WARMUP_STEPS:
+            tflops_list.append(step_tflops)
 
+    tflops_list.sort()
+    median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
+    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
     torch.cuda.synchronize()
...
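Under the defaults (NUM_STEPS=10, WARMUP_STEPS=1), nine post-warmup TFLOPS samples are collected and sorted. A standalone sketch of the index arithmetic the new code uses for the reported value (dummy samples, not benchmark output):

NUM_STEPS, WARMUP_STEPS = 10, 1
tflops_list = sorted(float(i) for i in range(NUM_STEPS - WARMUP_STEPS))  # 9 dummy samples
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS          # (9 >> 1) + 1 = 5
print(f"Median TFLOPS is {tflops_list[median_index]:.3f}")               # selects index 5 of 0..8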