OpenDAS / ColossalAI · Commits

Unverified commit a82da26f
Authored Apr 25, 2022 by Frank Lee; committed by GitHub on Apr 25, 2022

[cli] refactored micro-benchmarking cli and added more metrics (#858)

Parent: ee222dfb

Showing 7 changed files with 290 additions and 263 deletions (+290, -263):

colossalai/cli/benchmark/__init__.py        +27    -1
colossalai/cli/benchmark/benchmark.py       +94    -0
colossalai/cli/benchmark/models.py          +17    -0
colossalai/cli/benchmark/run.py              +0   -86
colossalai/cli/benchmark/simple_model.py     +0   -19
colossalai/cli/benchmark/utils.py          +151  -143
colossalai/cli/cli.py                        +1   -14
colossalai/cli/benchmark/__init__.py @ a82da26f
from random import choices
import click
from .utils import *
from .run import *
from .benchmark import run_benchmark
from colossalai.context import Config

__all__ = ['benchmark']


@click.command()
@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.")
@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.")
@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.")
@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.")
@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.")
@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.")
@click.option("-l", "--layers", type=int, default=2)
@click.option("-m", "--model", type=click.Choice(['mlp'], case_sensitive=False), default='mlp',
              help="Select the model to benchmark, currently only supports MLP")
def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int,
              profile_steps: int, layers: int, model: str):
    args_dict = locals()
    args = Config(args_dict)
    run_benchmark(args)
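For reference (not part of the commit), the short sketch below shows how the option values flow into run_benchmark: locals() captures the parsed options as a dict, and Config from colossalai.context wraps that dict with attribute access, which is exactly what run_benchmark reads (args.gpus, args.model, and so on). The concrete values used here simply mirror the command's defaults and are illustrative only.

from colossalai.context import Config

# hypothetical args, mirroring the defaults of the click options above
args = Config(dict(gpus=4, batch_size=8, seq_len=512, dimension=1024,
                   warmup_steps=10, profile_steps=50, layers=2, model='mlp'))
assert args.gpus == 4 and args.model == 'mlp'
# run_benchmark(args) would then spawn one process per GPU (see benchmark.py below)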
colossalai/cli/benchmark/benchmark.py (new file, mode 100644) @ a82da26f
import colossalai
import click
import torch.multiprocessing as mp

from functools import partial
from typing import List, Dict
from colossalai.context import Config
from colossalai.context.random import reset_seeds
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import free_port, MultiTimer
from colossalai.cli.benchmark.utils import find_all_configs, profile_model, get_batch_data
from .models import MLP


def run_benchmark(args: Config) -> None:
    """
    Run benchmarking with torch.multiprocessing.
    """

    # sanity checks
    if args.gpus is None:
        click.echo("Error: --num_gpus is not given")
        exit()

    click.echo("=== Benchmarking Parameters ===")
    for k, v in args.items():
        click.echo(f'{k}: {v}')
    click.echo('')

    config_list = find_all_configs(args.gpus)

    avail_ports = [free_port() for _ in range(len(config_list))]
    run_func = partial(run_dist_profiling,
                       world_size=args.gpus,
                       port_list=avail_ports,
                       config_list=config_list,
                       hyperparams=args)
    mp.spawn(run_func, nprocs=args.gpus)


def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict],
                       hyperparams: Config) -> None:
    """
    A function executed for profiling, this function should be spawn by torch.multiprocessing.

    Args:
        rank (int): rank of the process
        world_size (int): the number of processes
        port_list (List[int]): a list of free ports for initializing distributed networks
        config_list (List[Dict]): a list of configuration
        hyperparams (Config): the hyperparameters given by the user
    """

    # disable logging for clean output
    disable_existing_loggers()
    logger = get_dist_logger()
    logger.set_level('WARNING')

    for config, port in zip(config_list, port_list):
        colossalai.launch(config=config,
                          rank=rank,
                          world_size=world_size,
                          host='localhost',
                          port=port,
                          backend='nccl')
        timer = MultiTimer()

        if hyperparams.model == 'mlp':
            model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers)
        else:
            if gpc.get_global_rank() == 0:
                click.echo("Error: Invalid argument for --model")
            exit()

        data_func = partial(get_batch_data,
                            dim=hyperparams.dimension,
                            batch_size=hyperparams.batch_size,
                            seq_length=hyperparams.seq_len,
                            mode=config.parallel.tensor.mode)
        fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model,
                                                                      warmup_steps=hyperparams.warmup_steps,
                                                                      profile_steps=hyperparams.profile_steps,
                                                                      data_func=data_func,
                                                                      timer=timer)
        gpc.destroy()
        reset_seeds()

        if gpc.get_global_rank() == 0:
            config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()])
            click.echo(f"=== {config_str} ===")
            click.echo(f"Average forward time: {fwd_time}")
            click.echo(f"Average backward time: {bwd_time}")
            click.echo(f"Max allocated GPU memory: {max_allocated}")
            click.echo(f"Max cached GPU memory: {max_cached}\n")
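The spawning pattern used above (functools.partial plus torch.multiprocessing.spawn) can be illustrated in isolation. The standalone sketch below does not depend on Colossal-AI; it only shows that mp.spawn passes the local rank as the first positional argument while partial pre-binds everything else, just as run_benchmark does with world_size, port_list, config_list and hyperparams. The worker function and its arguments are hypothetical.

import torch.multiprocessing as mp
from functools import partial


def worker(rank: int, world_size: int, tag: str) -> None:
    # mp.spawn supplies `rank`; the remaining arguments were bound by partial
    print(f"[{tag}] rank {rank}/{world_size}")


if __name__ == '__main__':
    run_func = partial(worker, world_size=2, tag='demo')
    mp.spawn(run_func, nprocs=2)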
colossalai/cli/benchmark/models.py (new file, mode 100644) @ a82da26f
import torch
import colossalai.nn as col_nn


class MLP(torch.nn.Module):

    def __init__(self, dim: int, layers: int):
        super().__init__()
        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            self.layers.append(col_nn.Linear(dim, dim))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
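As a quick shape check (illustrative only; torch.nn.Linear stands in here for col_nn.Linear, which shards the weight under tensor parallelism but is called the same way), each Linear(dim, dim) layer maps a (batch_size, seq_len, dim) tensor to a tensor of the same shape, so the MLP above preserves the input shape regardless of the number of layers.

import torch

dim, layers = 1024, 2
stack = torch.nn.ModuleList(torch.nn.Linear(dim, dim) for _ in range(layers))
x = torch.rand(8, 512, dim)        # (batch_size, seq_len, dimension), cf. get_batch_data below
for layer in stack:
    x = layer(x)
assert x.shape == (8, 512, dim)    # shape preserved by every Linear(dim, dim)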
colossalai/cli/benchmark/run.py (deleted, mode 100644 → 0) @ ee222dfb
import torch
import inspect
import os
import subprocess
import sys

from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import print_rank_0
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port
from colossalai.cli.benchmark import build_args_parser, build_configs, \
    build_input_tensor, profile_1d, profile_2d, profile_2p5d, profile_3d, \
    BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES


def launch(args=None):
    train_script = inspect.getfile(inspect.currentframe())
    assert args is not None, "args should not be None"
    env = os.environ.copy()

    if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count():
        nproc_per_node = torch.cuda.device_count()
    else:
        nproc_per_node = args.num_gpus

    train_args = [f"--num_gpus={nproc_per_node}"]

    if args.bs != BATCH_SIZE:
        train_args.append(f"--bs={args.bs}")
    if args.hid_dim != HIDDEN_DIM:
        train_args.append(f"--hid_dim={args.hid_dim}")
    if args.num_steps != ITER_TIMES:
        train_args.append(f"--num_steps={args.num_steps}")
    if args.seq_len != SEQ_LENGTH:
        train_args.append(f"--seq_len={args.seq_len}")

    master_port = free_port()
    if torch.__version__ <= "1.09":
        cmd = [sys.executable, "-u", "-m", "torch.distributed.launch",
               f"--nproc_per_node={nproc_per_node}",
               f"--master_port={master_port}"] + [train_script] + train_args
    else:
        cmd = ["torchrun",
               f"--nproc_per_node={nproc_per_node}",
               f"--master_port={master_port}"] + [train_script] + train_args

    result = subprocess.Popen(cmd, env=env)
    result.wait()
    if result.returncode > 0:
        sys.exit(result.returncode)


def main():
    parser = build_args_parser()
    args = parser.parse_args()
    disable_existing_loggers()
    logger = get_dist_logger()
    launch_from_torch(config={}, verbose=False)
    input_tensor = build_input_tensor(args)
    config_dict = build_configs(args)

    if len(config_dict) == 0:
        print_rank_0(f"WARNING: We need at least two devices to profile TP strategies performance.")
        gpc.destroy()
        return

    for parallel_mode, config in config_dict.items():
        if parallel_mode == "1d":
            result_1d = profile_1d(input_tensor, config, args)
            print_rank_0(f"INFO: Totoal time cost in 1D TP is {result_1d}.")
        if parallel_mode == "2d":
            result_2d = profile_2d(input_tensor, config, args)
            print_rank_0(f"INFO: Totoal time cost in 2D TP is {result_2d}.")
        if parallel_mode == "2p5d":
            result_2p5d = profile_2p5d(input_tensor, config, args)
            print_rank_0(f"INFO: Totoal time cost in 2P5D TP is {result_2p5d}.")
        if parallel_mode == "3d":
            result_3d = profile_3d(input_tensor, config, args)
            print_rank_0(f"INFO: Totoal time cost in 3D TP is {result_3d}.")

    if "2d" not in config_dict:
        print_rank_0(f"WARNING: To use 2D tensor parallel, you have to provide at least 4 computing devices.")
    if "2p5d" not in config_dict:
        print_rank_0(f"WARNING: To use 2P5D tensor parallel, you have to provide at least 8 computing devices.")
        print_rank_0(f"WARNING: To use 3D tensor parallel, you have to provide at least 8 computing devices.")

    gpc.destroy()


if __name__ == "__main__":
    main()
colossalai/cli/benchmark/simple_model.py (deleted, mode 100644 → 0) @ ee222dfb
import torch
import colossalai
import colossalai.nn as col_nn


class MLP(torch.nn.Module):

    def __init__(self, dim: int = 256):
        super().__init__()
        intermediate_dim = dim * 4
        self.dense_1 = col_nn.Linear(dim, intermediate_dim)
        self.activation = torch.nn.GELU()
        self.dense_2 = col_nn.Linear(intermediate_dim, dim)
        self.dropout = col_nn.Dropout(0.1)

    def forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.dropout(x)
        return x
colossalai/cli/benchmark/utils.py @ a82da26f
import math
import time
from grpc import Call
import torch

from .simple_model import MLP
from colossalai.utils import Timer, synchronize
from colossalai.utils import MultiTimer
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from argparse import ArgumentParser

BATCH_SIZE = 8
SEQ_LENGTH = 120
HIDDEN_DIM = 1024
ITER_TIMES = 2000


def build_args_parser() -> ArgumentParser:
    """Helper function parsing the command line options."""
    parser = ArgumentParser(description="colossal benchmark")
    parser.add_argument("--num_gpus", type=int, default=-1, help="Total number of devices to use.")
    parser.add_argument("--bs", type=int, default=BATCH_SIZE, help="Batch size of the input tensor.")
    parser.add_argument("--seq_len", type=int, default=SEQ_LENGTH, help="Sequence length of the input tensor.")
    parser.add_argument("--hid_dim", type=int, default=HIDDEN_DIM, help="Hidden dimension of the input tensor.")
    parser.add_argument("--num_steps", type=int, default=ITER_TIMES, help="The number of iteration times.")
    return parser


def build_input_tensor(args):
    return torch.rand(args.bs, args.seq_len, args.hid_dim)


def build_configs_helper(device_cnt: int):
    config_dict = {}

    if device_cnt < 2:
        return config_dict

    if device_cnt < 4:
        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=2, mode='1d')))
    elif device_cnt < 8:
        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=4, mode='1d')))
        config_dict["2d"] = dict(parallel=dict(tensor=dict(size=4, mode='2d')))
    else:
        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=8, mode='1d')))
        config_dict["2d"] = dict(parallel=dict(data=2, tensor=dict(size=4, mode='2d')))
        config_dict["2p5d"] = dict(parallel=dict(tensor=dict(size=8, mode='2.5d', depth=2)))
        config_dict["3d"] = dict(parallel=dict(tensor=dict(size=8, mode='3d')))

    return config_dict
from colossalai.context import ParallelMode, Config
from typing import List, Dict, Tuple, Callable


def get_time_stamp() -> int:
    """
    Return the time stamp for profiling.

    Returns:
        time_stamp (int): the time given by time.time()
    """
    torch.cuda.synchronize()
    time_stamp = time.time()
    return time_stamp


def get_memory_states() -> Tuple[float]:
    """
    Return the memory statistics.

    Returns:
        max_allocated (float): the allocated CUDA memory
        max_cached (float): the cached CUDA memory
    """
    max_allocated = torch.cuda.max_memory_allocated() / (1024 ** 3)
    max_cached = torch.cuda.max_memory_reserved() / (1024 ** 3)
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    return max_allocated, max_cached


def find_all_configs(device_cnt: int) -> List[Dict]:
    """
    Find all possible configurations for tensor parallelism.

    Args:
        device_cnt (int): the number of devices

    Returns:
        config_list (List[Dict]): a list of configurations
    """

    def _is_square(num):
        return math.floor(math.sqrt(num)) ** 2 == num

    def _is_cube(num):
        return math.floor(num ** (1. / 3.)) ** 3 == num

    config_list = []

    # add non-parallel config
    config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None)))
    config_list.append(config)

    # add 1D config
    config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d')))
    config_list.append(config)

    # add 2D config only if device_cnt is a square
    if _is_square(device_cnt):
        config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d')))
        config_list.append(config)

    # check for 2.5D by iterating over the depth
    for depth in range(1, device_cnt):
        if device_cnt % depth == 0 and _is_square(device_cnt // depth):
            config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth)))
            config_list.append(config)

    # check for 3D if device_cnt is a cube
    if _is_cube(device_cnt):
        config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d')))
        config_list.append(config)

    config_list = [Config(cfg) for cfg in config_list]
    return config_list
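To make the enumeration above concrete, the standalone sketch below (illustrative only, not part of utils.py) re-derives the candidate tensor-parallel modes for a given device count with the same square and cube tests used by find_all_configs.

import math


def candidate_modes(device_cnt: int):
    # mirrors the enumeration logic of find_all_configs above
    def _is_square(num):
        return math.floor(math.sqrt(num)) ** 2 == num

    def _is_cube(num):
        return math.floor(num ** (1. / 3.)) ** 3 == num

    modes = [None, '1d']
    if _is_square(device_cnt):
        modes.append('2d')
    for depth in range(1, device_cnt):
        if device_cnt % depth == 0 and _is_square(device_cnt // depth):
            modes.append(('2.5d', depth))
    if _is_cube(device_cnt):
        modes.append('3d')
    return modes


print(candidate_modes(4))    # [None, '1d', '2d', ('2.5d', 1)]
print(candidate_modes(8))    # [None, '1d', ('2.5d', 2), '3d']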
def build_configs(args):
    total_device_cnt = torch.cuda.device_count()
    if args.num_gpus == -1:
        config_dict = build_configs_helper(total_device_cnt)
    else:
        valid_device_cnt = min(args.num_gpus, total_device_cnt)
        config_dict = build_configs_helper(valid_device_cnt)
    return config_dict


def profile_1d(input_tensor, config, args):
    gpc.load_config(config)
    gpc.init_parallel_groups()
    assert gpc.is_initialized(ParallelMode.PARALLEL_1D)
    model = MLP(args.hid_dim).cuda()
    input_tensor = input_tensor.cuda()
    torch.distributed.broadcast(input_tensor, src=0)
    timer = Timer()
    iter_times = args.num_steps
    timer.start()
    for i in range(iter_times):
        input_tensor = model(input_tensor)
    synchronize()
    result_1d = timer.stop()
    return result_1d


def profile_2d(input_tensor, config, args):
    gpc.load_config(config)
    gpc.init_parallel_groups()
    assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL)
    assert gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW)
    model = MLP(args.hid_dim).cuda()
    input_tensor = input_tensor.cuda()
    torch.distributed.broadcast(input_tensor, src=0)
    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)]
    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)]
    timer = Timer()
    iter_times = args.num_steps
    timer.start()
    for i in range(iter_times):
        input_tensor = model(input_tensor)
    synchronize()
    result_2d = timer.stop()
    return result_2d


def profile_2p5d(input_tensor, config, args):
    gpc.load_config(config)
    gpc.init_parallel_groups()
    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL)
    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW)
    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP)
    model = MLP(args.hid_dim).cuda()
    input_tensor = input_tensor.cuda()
    torch.distributed.broadcast(input_tensor, src=0)
    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)]
    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)]
    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)]
    timer = Timer()
    iter_times = args.num_steps
    timer.start()
    for i in range(iter_times):
        input_tensor = model(input_tensor)
    synchronize()
    result_2p5d = timer.stop()
    return result_2p5d


def profile_3d(input_tensor, config, args):
    gpc.load_config(config)
    gpc.init_parallel_groups()
    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_WEIGHT)
    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_INPUT)
    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_OUTPUT)
    model = MLP(args.hid_dim).cuda()
    input_tensor = input_tensor.cuda()
    torch.distributed.broadcast(input_tensor, src=0)
    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)]
    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)]
    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)]
    timer = Timer()
    iter_times = args.num_steps
    timer.start()
    for i in range(iter_times):
        input_tensor = model(input_tensor)
    synchronize()
    result_3d = timer.stop()
    return result_3d
def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
                  timer: MultiTimer) -> Tuple[float]:
    """
    Profile the forward and backward of a model.

    Args:
        model (torch.nn.Module): a PyTorch model
        warmup_steps (int): the number of steps for warmup
        profile_steps (int): the number of steps for profiling
        data_func (Callable): a function to generate random data
        timer (colossalai.utils.MultiTimer): a timer instance for time recording

    Returns:
        fwd_time (float): the average time taken by the forward pass in seconds
        bwd_time (float): the average time taken by the backward pass in seconds
        max_allocated (float): the maximum GPU memory allocated in GB
        max_cached (float): the maximum GPU memory cached in GB
    """

    def _run_step(data):
        timer.start('forward')
        out = model(data)
        timer.stop('forward', keep_in_history=True)
        timer.start('backward')
        out.mean().backward()
        timer.stop('backward', keep_in_history=True)

    data_list = [data_func() for _ in range(warmup_steps)]
    for data in data_list:
        _run_step(data)
    timer.reset('forward')
    timer.reset('backward')

    for _ in range(profile_steps):
        data = data_func()
        _run_step(data)

    max_allocated, max_cached = get_memory_states()
    fwd_time = timer.get_timer('forward').get_history_mean()
    bwd_time = timer.get_timer('backward').get_history_mean()
    return fwd_time, bwd_time, max_allocated, max_cached


def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor:
    """
    Return a random data of shape (batch_size, seq_length, dim) for profiling.

    Args:
        dim (int): hidden size
        batch_size (int): the number of data samples
        seq_length (int): the number of tokens
        mode (ParallelMode): Colossal-AI ParallelMode enum

    Returns:
        data (torch.Tensor): random data
    """

    if mode in ['2d', '2.5d']:
        batch_size = batch_size // 2
        dim = dim // 2
    elif mode == '3d':
        batch_size = batch_size // 4
        dim = dim // 2

    data = torch.rand(batch_size, seq_length, dim).cuda()
    return data
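A small standalone check (illustrative only; it mirrors the partitioning arithmetic of get_batch_data above without touching CUDA) of the per-rank tensor shape produced with the CLI defaults batch_size=8, seq_len=512, dimension=1024:

def local_shape(batch_size: int, seq_length: int, dim: int, mode) -> tuple:
    # same partitioning rule as get_batch_data above
    if mode in ['2d', '2.5d']:
        batch_size, dim = batch_size // 2, dim // 2
    elif mode == '3d':
        batch_size, dim = batch_size // 4, dim // 2
    return (batch_size, seq_length, dim)


assert local_shape(8, 512, 1024, None) == (8, 512, 1024)
assert local_shape(8, 512, 1024, '2d') == (4, 512, 512)
assert local_shape(8, 512, 1024, '3d') == (2, 512, 512)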
colossalai/cli/cli.py @ a82da26f
import click
from .launcher import run
from .check import check
from colossalai.cli.benchmark.utils import BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES
from colossalai.cli.benchmark.run import launch as col_benchmark
from .benchmark import benchmark


class Arguments():
...
...
@@ -17,18 +16,6 @@ def cli():
    pass


@click.command()
@click.option("--num_gpus", type=int, default=-1)
@click.option("--bs", type=int, default=BATCH_SIZE)
@click.option("--seq_len", type=int, default=SEQ_LENGTH)
@click.option("--hid_dim", type=int, default=HIDDEN_DIM)
@click.option("--num_steps", type=int, default=ITER_TIMES)
def benchmark(num_gpus, bs, seq_len, hid_dim, num_steps):
    args_dict = locals()
    args = Arguments(args_dict)
    col_benchmark(args)


cli.add_command(run)
cli.add_command(check)
cli.add_command(benchmark)
...
...