Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d5e3e3ec
"vscode:/vscode.git/clone" did not exist on "a09f88ab0787bed82d33d9d7b68b759beb8e8e06"
Unverified
Commit
d5e3e3ec
authored
Dec 28, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 28, 2022
Browse files
[example] update gpt example for larger model scale (#2211)
parent
24246f7a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
122 additions
and
57 deletions
+122
-57
colossalai/gemini/memory_tracer/memstats_collector.py
colossalai/gemini/memory_tracer/memstats_collector.py
+0
-2
examples/language/gpt/README.md
examples/language/gpt/README.md
+17
-3
examples/language/gpt/model_zoo.py
examples/language/gpt/model_zoo.py
+71
-0
examples/language/gpt/run.sh
examples/language/gpt/run.sh
+7
-4
examples/language/gpt/train_gpt_demo.py
examples/language/gpt/train_gpt_demo.py
+27
-48
No files found.
colossalai/gemini/memory_tracer/memstats_collector.py
View file @
d5e3e3ec
...
...
@@ -59,7 +59,6 @@ class MemStatsCollector:
return
[
t
-
self
.
_sampling_time
[
0
]
for
t
in
self
.
_sampling_time
]
def
start_collection
(
self
):
print
(
'start collection'
)
self
.
_start_flag
=
True
self
.
_mem_monitor
.
start
()
...
...
@@ -68,7 +67,6 @@ class MemStatsCollector:
# self._step_total = len(self._sampling_time)
self
.
_step_total
=
len
(
self
.
_memstats
.
non_model_data_list
(
'cuda'
))
self
.
_start_flag
=
False
self
.
_mem_monitor
.
finish
()
print
(
f
'finish_collection
{
self
.
_step_total
}
'
)
# deprecated
...
...
examples/language/gpt/README.md
View file @
d5e3e3ec
...
...
@@ -62,7 +62,7 @@ ColossalAI version 0.1.13.
How does Batch Size affect the efficiency.
| model | #GPU | policy | TP |batch | Tflops |
| model | #GPU | policy | TP |
batch
per DP
| Tflops |
| ---------- | --------- |--------- |--------- |--------- |--------- |
| gpt2_10b | 2 | cpu | 1 | 32 | 122.046 |
| gpt2_10b | 2 | cpu | 1 | 16 | 82.649 |
...
...
@@ -71,7 +71,7 @@ How dose Batch Size affect the efficency.
How does the Placement Policy affect the efficiency.
| model | #GPU | policy | TP |batch | Tflops |
| model | #GPU | policy | TP |
batch
per DP
| Tflops |
| ---------- | --------- |--------- |--------- |--------- |--------- |
| gpt2_10b | 4 | auto | 1 | 8 | 88.657 |
| gpt2_10b | 4 | cuda | 1 | 8 | OOM |
...
...
@@ -80,9 +80,23 @@ How dose the Placement Policy affect the efficency.
How does the Tensor Parallel Degree affect the efficiency.
| model | #GPU | policy | TP |batch | Tflops |
| model | #GPU | policy | TP |
batch
per DP
| Tflops |
| ---------- | --------- |--------- |--------- |--------- |--------- |
| gpt2_10b | 4 | auto | 1 | 8 | 88.657 |
| gpt2_10b | 4 | auto | 2 | 8 | 56.687 |
| gpt2_10b | 4 | auto | 4 | 8 | 29.019 |
| gpt2_10b | 4 | auto | 4 | 64 | 50.411 |
| gpt2_20b | 1 | cpu | 1 | 8 | 43.102 |
| gpt2_20b | 4 | cpu | 4 | 8 | 28.491 |
Pushing the limits of model scale and batch size.
| model | #GPU | policy | TP | batch per DP | Tflops |
| ---------- | --------- |--------- |--------- |--------- |--------- |
| gpt2_20b | 4 | cpu | 1 | 64 | CUDA OOM |
| gpt2_20b | 4 | auto | 1/2 | 64 | CUDA OOM |
| gpt2_20b | 4 | cpu | 2 | 64 | 121.394 |
| gpt2_20b | 4 | cpu | 2 | 8 | 43.102 |
| gpt2_20b | 8 | cpu | 2 | 64 | 125.170 |
examples/language/gpt/model_zoo.py
0 → 100644
View file @
d5e3e3ec
from
torch
import
nn
from
transformers
import
GPT2Config
,
GPT2LMHeadModel
## GPT2 language model wrapper built on Huggingface transformers' GPT2LMHeadModel.
class GPTLMModel(nn.Module):
    """Wraps ``GPT2LMHeadModel`` so that ``forward`` returns only the lm logits.

    The constructor arguments mirror the corresponding ``GPT2Config`` fields;
    ``checkpoint`` turns on gradient (activation) checkpointing to trade
    recomputation for memory.
    """

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        gpt2_config = GPT2Config(n_embd=hidden_size,
                                 n_layer=num_layers,
                                 n_head=num_attention_heads,
                                 n_positions=max_seq_len,
                                 n_ctx=max_seq_len,
                                 vocab_size=vocab_size)
        self.model = GPT2LMHeadModel(gpt2_config)
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # The underlying model returns a tuple; element 0 is the lm logits.
        # The KV cache is disabled while checkpointing is on — presumably
        # because caching is pointless/incompatible during recomputation.
        outputs = self.model(input_ids=input_ids,
                             attention_mask=attention_mask,
                             use_cache=not self.checkpoint)
        return outputs[0]
def gpt2_medium(checkpoint=False):
    """Build a GPT2-medium model: 1024 hidden size, 24 layers, 16 heads."""
    dims = {"hidden_size": 1024, "num_layers": 24, "num_attention_heads": 16}
    return GPTLMModel(checkpoint=checkpoint, **dims)
def gpt2_xl(checkpoint=True):
    """Build a GPT2-XL model: 1600 hidden size, 48 layers, 32 heads."""
    dims = {"hidden_size": 1600, "num_layers": 48, "num_attention_heads": 32}
    return GPTLMModel(checkpoint=checkpoint, **dims)
def gpt2_10b(checkpoint=True):
    """Build a ~10B-parameter GPT2: 4096 hidden size, 50 layers, 16 heads."""
    dims = {"hidden_size": 4096, "num_layers": 50, "num_attention_heads": 16}
    return GPTLMModel(checkpoint=checkpoint, **dims)
def gpt2_14b(checkpoint=True):
    """Build a ~14B-parameter GPT2: 4096 hidden size, 70 layers, 16 heads."""
    dims = {"hidden_size": 4096, "num_layers": 70, "num_attention_heads": 16}
    return GPTLMModel(checkpoint=checkpoint, **dims)
def gpt2_20b(checkpoint=True):
    """Build a ~20B-parameter GPT2: 8192 hidden size, 25 layers, 16 heads."""
    dims = {"hidden_size": 8192, "num_layers": 25, "num_attention_heads": 16}
    return GPTLMModel(checkpoint=checkpoint, **dims)
def gpt2_24b(checkpoint=True):
    """Build a ~24B-parameter GPT2: 8192 hidden size, 30 layers, 16 heads."""
    dims = {"hidden_size": 8192, "num_layers": 30, "num_attention_heads": 16}
    return GPTLMModel(checkpoint=checkpoint, **dims)
def model_builder(model_size: str):
    """Return the model factory function for the given model-size name.

    Args:
        model_size: one of "gpt2_medium", "gpt2_xl", "gpt2_10b",
            "gpt2_14b", "gpt2_20b", "gpt2_24b".

    Returns:
        A zero-config factory ``f(checkpoint: bool) -> GPTLMModel``.

    Raises:
        ValueError: if ``model_size`` is not one of the known names.
    """
    if model_size == "gpt2_medium":
        return gpt2_medium
    elif model_size == "gpt2_xl":
        return gpt2_xl
    elif model_size == "gpt2_10b":
        return gpt2_10b
    elif model_size == "gpt2_14b":
        return gpt2_14b
    elif model_size == "gpt2_20b":
        return gpt2_20b
    elif model_size == "gpt2_24b":
        return gpt2_24b
    # Fail loudly instead of silently returning None, which would otherwise
    # surface later as a confusing "'NoneType' object is not callable" at the
    # call site (e.g. model_builder(args.model_type)(checkpoint=True)).
    raise ValueError(f"Unknown model size: {model_size!r}")
__all__
=
[
'model_builder'
]
examples/language/gpt/run.sh
View file @
d5e3e3ec
...
...
@@ -2,9 +2,12 @@
export
DISTPAN
=
"colossalai"
# The following options only valid when DISTPAN="colossalai"
export
TPDEGREE
=
4
export
GPUNUM
=
4
export
PLACEMENT
=
'
auto
'
export
TPDEGREE
=
2
export
GPUNUM
=
8
export
PLACEMENT
=
'
cpu
'
export
USE_SHARD_INIT
=
False
export
BATCH_SIZE
=
64
export
MODEL_TYPE
=
"gpt2_20b"
env
OMP_NUM_THREADS
=
16 torchrun
--standalone
--nproc_per_node
=
${
GPUNUM
}
train_gpt_demo.py
--tp_degree
=
${
TPDEGREE
}
--placement
${
PLACEMENT
}
--shardinit
${
USE_SHARD_INIT
}
--distplan
${
DISTPAN
}
2>&1 |
tee
run.log
mkdir
-p
logs
env
OMP_NUM_THREADS
=
16 torchrun
--standalone
--nproc_per_node
=
${
GPUNUM
}
train_gpt_demo.py
--tp_degree
=
${
TPDEGREE
}
--model_type
=
${
MODEL_TYPE
}
--batch_size
=
${
BATCH_SIZE
}
--placement
${
PLACEMENT
}
--shardinit
${
USE_SHARD_INIT
}
--distplan
${
DISTPAN
}
2>&1 |
tee
./logs/
${
MODEL_TYPE
}
_
${
DISTPAN
}
_gpu_
${
GPUNUM
}
_bs_
${
BATCH_SIZE
}
_tp_
${
TPDEGREE
}
.log
examples/language/gpt/train_gpt_demo.py
View file @
d5e3e3ec
...
...
@@ -6,18 +6,16 @@ import torch
import
torch.nn
as
nn
from
packaging
import
version
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
transformers
import
GPT2Config
,
GPT2LMHeadModel
import
colossalai
from
colossalai.logging
import
disable_existing_loggers
,
get_dist_logger
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer.gemini_optimizer
import
GeminiAdamOptimizer
from
colossalai.nn.optimizer.zero_optimizer
import
ZeroOptimizer
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.tensor
import
ColoParameter
,
ComputePattern
,
ComputeSpec
,
ProcessGroup
,
ReplicaSpec
,
ShardSpec
from
colossalai.utils
import
get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.zero.sharded_optim
import
LowLevelZeroOptimizer
from
model_zoo
import
model_builder
def
parse_args
():
...
...
@@ -47,6 +45,18 @@ def parse_args():
help
=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan."
,
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
8
,
help
=
"batch size per DP group of training."
,
)
parser
.
add_argument
(
"--model_type"
,
type
=
str
,
default
=
'gpt2_medium'
,
help
=
"model model scale"
,
)
args
=
parser
.
parse_args
()
return
args
...
...
@@ -65,33 +75,6 @@ def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d
(
-
1
,
param
,
pg
)
## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
class
GPTLMModel
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
=
768
,
num_layers
=
12
,
num_attention_heads
=
12
,
max_seq_len
=
1024
,
vocab_size
=
50257
,
checkpoint
=
False
):
super
().
__init__
()
self
.
checkpoint
=
checkpoint
self
.
model
=
GPT2LMHeadModel
(
GPT2Config
(
n_embd
=
hidden_size
,
n_layer
=
num_layers
,
n_head
=
num_attention_heads
,
n_positions
=
max_seq_len
,
n_ctx
=
max_seq_len
,
vocab_size
=
vocab_size
))
if
checkpoint
:
self
.
model
.
gradient_checkpointing_enable
()
def
forward
(
self
,
input_ids
,
attention_mask
):
# Only return lm_logits
return
self
.
model
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
use_cache
=
not
self
.
checkpoint
)[
0
]
class
GPTLMLoss
(
nn
.
Module
):
def
__init__
(
self
):
...
...
@@ -112,18 +95,6 @@ def get_data(batch_size, seq_len, vocab_size):
return
input_ids
,
attention_mask
def
gpt2_medium
(
checkpoint
=
False
):
return
GPTLMModel
(
hidden_size
=
1024
,
num_layers
=
24
,
num_attention_heads
=
16
,
checkpoint
=
checkpoint
)
def
gpt2_xl
(
checkpoint
=
True
):
return
GPTLMModel
(
hidden_size
=
1600
,
num_layers
=
48
,
num_attention_heads
=
32
,
checkpoint
=
checkpoint
)
def
gpt2_10b
(
checkpoint
=
True
):
return
GPTLMModel
(
hidden_size
=
4096
,
num_layers
=
50
,
num_attention_heads
=
16
,
checkpoint
=
checkpoint
)
def
get_cpu_mem
():
return
psutil
.
Process
().
memory_info
().
rss
/
1024
**
2
...
...
@@ -210,7 +181,8 @@ def main():
if
args
.
distplan
not
in
[
"colossalai"
,
"torch_ddp"
,
"torch_zero"
,
"zero1"
,
"zero2"
]:
raise
TypeError
(
f
"
{
args
.
distplan
}
is error"
)
BATCH_SIZE
=
64
# batch size per DP degree
BATCH_SIZE
=
args
.
batch_size
SEQ_LEN
=
1024
VOCAB_SIZE
=
50257
...
...
@@ -220,7 +192,7 @@ def main():
colossalai
.
launch_from_torch
(
config
=
{})
logger
=
get_dist_logger
()
logger
.
info
(
f
"
using dist plan
{
args
.
distplan
}
"
,
ranks
=
[
0
])
logger
.
info
(
f
"
{
args
.
model_type
}
,
{
args
.
distplan
}
, batch size
{
BATCH_SIZE
}
"
,
ranks
=
[
0
])
# build criterion
criterion
=
GPTLMLoss
()
...
...
@@ -232,8 +204,11 @@ def main():
default_dist_spec
=
ShardSpec
([
-
1
],
[
args
.
tp_degree
])
if
args
.
shardinit
else
None
# build GPT model
with
ColoInitContext
(
device
=
get_current_device
(),
default_dist_spec
=
default_dist_spec
,
default_pg
=
default_pg
):
model
=
gpt2_10b
(
checkpoint
=
True
)
with
ColoInitContext
(
device
=
get_current_device
(),
dtype
=
torch
.
half
,
default_dist_spec
=
default_dist_spec
,
default_pg
=
default_pg
):
model
=
model_builder
(
args
.
model_type
)(
checkpoint
=
True
)
pg
=
default_pg
# Tensor Parallelism (TP)
...
...
@@ -246,7 +221,7 @@ def main():
optimizer
=
GeminiAdamOptimizer
(
model
,
lr
=
1e-3
,
initial_scale
=
2
**
5
)
logger
.
info
(
get_mem_info
(
prefix
=
'After init optim, '
),
ranks
=
[
0
])
else
:
model
=
gpt2_10b
(
checkpoint
=
True
).
cuda
()
model
=
model_builder
(
args
.
model_type
)
(
checkpoint
=
True
).
cuda
()
if
args
.
distplan
.
startswith
(
"torch"
):
model
=
DDP
(
model
)
...
...
@@ -262,10 +237,14 @@ def main():
overlap_communication
=
True
,
partition_grad
=
partition_flag
,
verbose
=
True
)
# notice that the model is still in fp32
# model is shared after TP
numel
=
sum
([
p
.
numel
()
for
p
in
model
.
parameters
()])
logger
.
info
(
get_mem_info
(
prefix
=
'After init model, '
),
ranks
=
[
0
])
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
# = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
# = batch_per_DP_group * numel * seq_len * 8
get_tflops_func
=
partial
(
get_tflops
,
numel
,
BATCH_SIZE
,
SEQ_LEN
)
torch
.
cuda
.
synchronize
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment