OpenDAS / ColossalAI · Commits

Commit a25f7553 (unverified)
Authored Nov 08, 2022 by Jiarui Fang; committed via GitHub on Nov 08, 2022
Parent: 49216d7a

[example] add TP to GPT example (#1828)
Showing 4 changed files with 114 additions and 56 deletions (+114 −56):

- examples/language/gpt/run.sh (+1 −1)
- examples/language/gpt/train_gpt_demo.py (+104 −24)
- examples/language/opt/run_clm.py (+9 −3)
- examples/language/opt/utils.py (+0 −28, file deleted)
examples/language/gpt/run.sh

```diff
-env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=2 train_gpt_demo.py 2>&1 | tee run.log
+env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log
```
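The updated launch line doubles the process count and turns on tensor parallelism: with four ranks and a TP degree of 2, the job forms two tensor-parallel groups, leaving a data-parallel degree of two, and `--placement='cpu'` asks Gemini to keep managed parameters in host memory. A quick sanity check of the rank arithmetic (plain Python, illustrative only, not part of the example):

```python
# Hypothetical sanity check: how --nproc_per_node decomposes under --tp_degree.
world_size = 4                          # torchrun --nproc_per_node=4
tp_degree = 2                           # --tp_degree=2
assert world_size % tp_degree == 0, "TP degree must divide the world size"
dp_degree = world_size // tp_degree     # 2 data-parallel replicas of 2-way TP groups
print(f"{dp_degree} DP x {tp_degree} TP = {world_size} ranks")
```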
examples/language/gpt/train_gpt_demo.py

```diff
@@ -10,13 +10,48 @@ import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.parallel import ZeroDDP
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.zero import ZeroOptimizer
 from transformers import GPT2Config, GPT2LMHeadModel
 
 
+def parse_args():
+    parser = colossalai.get_default_parser()
+    parser.add_argument(
+        "--tp_degree",
+        type=int,
+        default=1,
+        help="Tensor Parallelism Degree.",
+    )
+    parser.add_argument(
+        "--placement",
+        type=str,
+        default='cpu',
+        help="Placement Policy for Gemini.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+## Parameter Sharding Strategies for Tensor Parallelism
+def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
+    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    if param.process_group.tp_world_size() == 1:
+        param.set_process_group(pg)
+    param.set_tensor_spec(*spec)
+
+
+def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
+    split_param_single_dim_tp1d(0, param, pg)
+
+
+def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
+    split_param_single_dim_tp1d(-1, param, pg)
+
+
+## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
 class GPTLMModel(nn.Module):
 
     def __init__(self,
```
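The helpers above encode 1D tensor parallelism: `ShardSpec([dim], [pg.tp_world_size()])` splits a parameter along one dimension across the TP group, with dim 0 a row slice and dim -1 a column slice. The following standalone sketch (plain PyTorch, not the ColossalAI API) illustrates what each split does to a parameter:

```python
import torch

# Illustrative sketch of 1D sharding on a 2-way TP group (not ColossalAI code).
tp_world_size = 2
weight = torch.arange(24.0).reshape(4, 6)   # a toy (4, 6) parameter

row_shards = torch.chunk(weight, tp_world_size, dim=0)    # split_param_row_tp1d: two (2, 6) shards
col_shards = torch.chunk(weight, tp_world_size, dim=-1)   # split_param_col_tp1d: two (4, 3) shards

# Re-concatenating along the split dimension recovers the original parameter.
assert torch.equal(torch.cat(row_shards, dim=0), weight)
assert torch.equal(torch.cat(col_shards, dim=-1), weight)
```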
```diff
@@ -56,6 +91,7 @@ class GPTLMLoss(nn.Module):
         return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
 
+## Randomly Generated Data
 def get_data(batch_size, seq_len, vocab_size):
     input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
     attention_mask = torch.ones_like(input_ids)
```
```diff
@@ -90,54 +126,96 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
     return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
```
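The constant 8 in this formula is consistent with roughly 2 FLOPs per parameter per token (one multiply, one add), times four passes over the weights: forward, the two backward products, and one extra forward recomputation from activation checkpointing (the model is built with `checkpoint=True`). A worked example with illustrative numbers:

```python
# Worked example of get_tflops with illustrative numbers (gpt2-medium ~ 355M params).
model_numel, batch_size, seq_len, step_time = 355e6, 8, 1024, 1.0
# 2 FLOPs/param/token * (1 fwd + 2 bwd + 1 recomputed fwd) = 8
tflops = model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
print(f"{tflops:.1f} TFLOPS")   # ~23.3 for a 1 s step
```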
The same hunk continues with the new sharding routine:

```diff
+# Tensor Parallel
+def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
+    """tensor_parallelize
+    Sharding the Model Parameters.
+
+    Args:
+        model (torch.nn.Module): a torch module to be sharded
+    """
+    for mn, module in model.named_modules():
+        for pn, param in module.named_parameters(recurse=False):
+            # set process group for all parameters
+            param.set_process_group(pg)
+
+            if 'mlp.c_fc' in mn:
+                if 'weight' in pn or 'bias' in pn:
+                    split_param_col_tp1d(param, pg)    # column slice
+                    # keep the shape of the output from c_fc
+                    param.compute_spec.set_output_replicate(False)
+            elif 'mlp.c_proj' in mn:
+                if 'weight' in pn:
+                    split_param_row_tp1d(param, pg)    # row slice
+            elif 'wte' in mn or 'wpe' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
+            elif 'c_attn' in mn or 'c_proj' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
```
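The split choices follow the Megatron-LM recipe for the GPT-2 MLP: `c_fc` is column-sliced along its output dimension (hence `set_output_replicate(False)`, so the intermediate activation stays sharded), and `c_proj` is row-sliced along its input dimension, so each rank produces a partial output that a single all-reduce restores. A single-process sketch of why this is exact (plain PyTorch, with ReLU standing in for GeLU; not ColossalAI code):

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 16)        # (batch, hidden)
w_fc = torch.randn(16, 64)    # c_fc weight, (in, out) as in HF's Conv1D
w_proj = torch.randn(64, 16)  # c_proj weight

reference = torch.relu(x @ w_fc) @ w_proj          # unsharded MLP

fc_shards = torch.chunk(w_fc, 2, dim=-1)           # column slice, as for 'mlp.c_fc'
proj_shards = torch.chunk(w_proj, 2, dim=0)        # row slice, as for 'mlp.c_proj'

# Each simulated rank computes a partial result from its shard pair; summing the
# partials (an all-reduce in the distributed setting) matches the reference.
partials = [torch.relu(x @ fc_shards[r]) @ proj_shards[r] for r in range(2)]
assert torch.allclose(sum(partials), reference, atol=1e-5)
```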
…and with the version-dispatching Gemini/ZeRO wrapper (the chunk manager is built before the Gemini manager that consumes it, and the unsupported-version branch raises `NotImplementedError`):

```diff
+# Gemini + ZeRO DDP
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
+    cai_version = colossalai.__version__
+    if version.parse(cai_version) > version.parse("0.1.10"):
+        from colossalai.nn.parallel import GeminiDDP
+        model = GeminiDDP(model,
+                          device=get_current_device(),
+                          placement_policy=placement_policy,
+                          pin_memory=True,
+                          search_range_mb=32)
+    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
+        from colossalai.gemini import ChunkManager, GeminiManager
+        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
+        # build the chunk manager first; the Gemini manager consumes it
+        chunk_manager = ChunkManager(chunk_size,
+                                     pg,
+                                     enable_distributed_storage=True,
+                                     init_device=GeminiManager.get_default_device(placement_policy))
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
+        model = ZeroDDP(model, gemini_manager)
+    else:
+        raise NotImplementedError(f"CAI version {cai_version} is not supported")
+    return model
```
The rewritten `main()` wires the two together, replacing the old inline version-dispatch block:

```diff
 
 def main():
+    args = parse_args()
+
     BATCH_SIZE = 8
     SEQ_LEN = 1024
     VOCAB_SIZE = 50257
     NUM_STEPS = 10
-    PLACEMENT_POLICY = 'auto'
 
     disable_existing_loggers()
     colossalai.launch_from_torch(config={})
-    pg = ProcessGroup()
+    pg = ProcessGroup(tp_degree=args.tp_degree)
     logger = get_dist_logger()
 
     logger.info(get_mem_info(), ranks=[0])
     # build GPT model
     with ColoInitContext(device=get_current_device()):
         model = gpt2_medium(checkpoint=True)
 
     numel = sum([p.numel() for p in model.parameters()])
     logger.info(f'Model numel: {numel}', ranks=[0])
     get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
 
-    cai_version = colossalai.__version__
-    logger.info(f'using Colossal-AI version {cai_version}')
-    if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.nn.parallel import GeminiDDP
-        model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True, search_range_mb=32)
-    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
-        from colossalai.gemini import ChunkManager, GeminiManager
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=True, init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
-        gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
-        model = ZeroDDP(model, gemini_manager)
+    # Tensor Parallelism (TP)
+    tensor_parallelize(model, pg)
+    # Gemini + ZeRO DP; note it must be applied after TP
+    model = gemini_zero_dpp(model, pg, args.placement)
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
 
     # build criterion
     criterion = GPTLMLoss()
 
-    # optimizer
+    # build optimizer
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
     logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
 
     torch.cuda.synchronize()
     model.train()
     for n in range(NUM_STEPS):
         # we just use randomly generated data here
```
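The collapsed region is the demo's training loop over the randomly generated batches. Its body is not shown in this hunk; a plausible reconstruction of one step, under the assumption that ColossalAI's `ZeroOptimizer` drives the backward pass, would look like this (a sketch, not the file's verbatim contents):

```python
# Sketch of a typical step in this demo (reconstruction under assumptions):
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
optimizer.zero_grad()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
optimizer.backward(loss)   # ZeroOptimizer.backward, instead of loss.backward()
optimizer.step()
```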
```diff
@@ -156,6 +234,8 @@
             f'[{n + 1}/{NUM_STEPS}] Loss: {loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', ranks=[0])
 
+    torch.cuda.synchronize()
+
 
 if __name__ == '__main__':
     main()
```
examples/language/opt/run_clm.py

```diff
@@ -36,7 +36,6 @@ from datasets import load_dataset
 from packaging import version
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-from utils import colo_memory_cap
 
 import colossalai
 import transformers
@@ -47,7 +46,6 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ProcessGroup
 from colossalai.utils import get_current_device, get_dataloader
 from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.zero import ZeroOptimizer
 from transformers import (
@@ -249,12 +247,20 @@ def parse_args():
     return args
 
 
+def colo_memory_cap(size_in_GB):
+    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
+    cuda_capacity = colo_device_memory_capacity(get_current_device())
+    if size_in_GB * (1024**3) < cuda_capacity:
+        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
+        print("Using {} GB of GPU memory".format(size_in_GB))
+
+
 def main():
     args = parse_args()
     disable_existing_loggers()
     colossalai.launch_from_torch(config=dict())
     logger = get_dist_logger()
-    is_main_process = gpc.get_local_rank(ParallelMode.DATA) == 0
+    is_main_process = dist.get_rank() == 0
 
     if is_main_process:
         datasets.utils.logging.set_verbosity_warning()
```
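With `colo_memory_cap` defined locally and rank detection switched to `torch.distributed.get_rank()`, run_clm.py no longer depends on the deleted `utils` module, nor on ColossalAI's global parallel context (`gpc`/`ParallelMode`) for this check; rank 0 of the default process group serves as the main process.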
examples/language/opt/utils.py (deleted, 100644 → 0)

```diff
-import torch
-import torch.distributed as dist
-
-
-def memory_cap(size_in_GB):
-    print(f"use only {size_in_GB} GB of CUDA memory")
-    assert dist.is_initialized(), "memory_cap must be used after dist init"
-    local_rank = dist.get_rank()
-    cuda_capacity = torch.cuda.get_device_properties(local_rank).total_memory
-    size_in_B = (size_in_GB * 1024**3)
-    if size_in_B > cuda_capacity:
-        print(f'memory_cap is useless since {cuda_capacity / 1024**3} is less than {size_in_GB}')
-        return
-    fraction = (size_in_GB * 1024**3) / cuda_capacity
-    print(f'mem fraction is {fraction}')
-    torch.cuda.set_per_process_memory_fraction(fraction, local_rank)
-
-
-def colo_memory_cap(size_in_GB):
-    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
-    cuda_capacity = colo_device_memory_capacity(get_current_device())
-    if size_in_GB * (1024**3) < cuda_capacity:
-        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
-        print("Using {} GB of GPU memory".format(size_in_GB))
-
-
-if __name__ == '__main__':
-    memory_cap(40)
```
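Of the two helpers in the deleted file, only `colo_memory_cap` was still needed; it now lives directly in run_clm.py (see the hunk above). The raw-PyTorch `memory_cap` variant, which capped memory itself via `torch.cuda.set_per_process_memory_fraction`, is dropped along with the file.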