ColossalAI · Commit 2ef855c7 (unverified)
Authored Mar 08, 2023 by ramos; committed by GitHub on Mar 08, 2023
support shardinit option to avoid OPT OOM initializing problem (#3037)
Co-authored-by: poe <poe@nemoramo>
Parent: 29386a54
2 changed files with 33 additions and 3 deletions:

examples/language/opt/run_gemini.sh (+8 / -0)
examples/language/opt/train_gemini_opt.py (+25 / -3)
examples/language/opt/run_gemini.sh

@@ -4,10 +4,17 @@ export MEMCAP=${MEMCAP:-0}
 # Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b`
 export MODEL=${MODEL:-"125m"}
 export GPUNUM=${GPUNUM:-1}
+export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"}

 # make directory for logs
 mkdir -p ./logs

+if [ ${USE_SHARD_INIT} = "true" ]; then
+  USE_SHARD_INIT="--shardinit"
+else
+  USE_SHARD_INIT=""
+fi
+
 export MODLE_PATH="facebook/opt-${MODEL}"

 # HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
@@ -17,4 +24,5 @@ torchrun \
   train_gemini_opt.py \
   --mem_cap ${MEMCAP} \
   --model_name_or_path ${MODLE_PATH} \
+  ${USE_SHARD_INIT} \
   --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
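With this change, sharded initialization is driven entirely by an environment variable: the script maps `USE_SHARD_INIT="true"` to the new `--shardinit` flag and forwards it to `train_gemini_opt.py`. An illustrative invocation (example values are mine, and it assumes `BS` is exported by the caller or defaulted elsewhere in the script, which this excerpt does not show): `USE_SHARD_INIT="true" MODEL="6.7b" GPUNUM=4 BS=16 bash run_gemini.sh`.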
examples/language/opt/train_gemini_opt.py

@@ -39,6 +39,8 @@ from colossalai.nn.parallel import GeminiDDP
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import ProcessGroup, ShardSpec
+

 def get_data(batch_size, seq_len, vocab_size):
     input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
@@ -102,6 +104,11 @@ def parse_args():
         help="Model type to use if training from scratch.",
         choices=MODEL_TYPES,
     )
+    parser.add_argument(
+        "--shardinit",
+        action="store_true",
+        help="Initialize the model with tensor parallel",
+    )
     parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
     parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
     args = parser.parse_args()
@@ -159,16 +166,30 @@ def main():
     else:
         init_dev = get_current_device()

+    # shard init prameters
+    if args.shardinit:
+        logger.info("Sharding initialization !", ranks=[0])
+    else:
+        logger.info("Skipping sharding initialization", ranks=[0])
+    world_size = torch.distributed.get_world_size()
+    shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
+    default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
     # build model
     if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
         # currently, there has a bug in pretrained opt-13b
         # we can not import it until huggingface fix it
         logger.info("Train a new model from scratch", ranks=[0])
-        with ColoInitContext(device=init_dev, dtype=torch.half):
+        with ColoInitContext(device=init_dev,
+                             dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
             model = OPTForCausalLM(config)
     else:
         logger.info("Finetune a pre-trained model", ranks=[0])
-        with ColoInitContext(device=init_dev, dtype=torch.half):
+        with ColoInitContext(device=init_dev,
+                             dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
             model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
                                                    from_tf=bool(".ckpt" in args.model_name_or_path),
                                                    config=config,
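To make the pattern this hunk introduces easier to follow, here is a minimal, self-contained sketch (my illustration, not part of the commit): every parameter created inside the ColoInitContext is sharded along its last dimension across all ranks, so no single process ever materializes the full fp16 model, which is the OPT initialization OOM this commit avoids. It assumes ColossalAI's 0.2-era APIs, a torchrun launch, and substitutes a plain torch.nn.Linear for OPTForCausalLM.

# shard_init_sketch.py -- illustrative only; launch with:
#   torchrun --nproc_per_node=2 shard_init_sketch.py
import torch
import colossalai
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

colossalai.launch_from_torch(config={})
world_size = torch.distributed.get_world_size()

# A process group spanning all ranks, plus a default spec that shards each
# newly created parameter along its last dimension over `world_size` ranks.
shard_pg = ProcessGroup(tp_degree=world_size)
default_dist_spec = ShardSpec([-1], [world_size])

with ColoInitContext(device=get_current_device(),
                     dtype=torch.half,
                     default_dist_spec=default_dist_spec,
                     default_pg=shard_pg):
    # Stand-in for OPTForCausalLM(config): each rank allocates only its
    # 1/world_size slice of the weight rather than the full matrix.
    model = torch.nn.Linear(4096, 4096, bias=False)

print(f"rank {torch.distributed.get_rank()}: model initialized with sharded parameters")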
@@ -179,7 +200,8 @@ def main():
     numel = sum([p.numel() for p in model.parameters()])

     PLACEMENT_POLICY = 'cpu'
-    model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
+    model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True,
+                      strict_ddp_mode=args.shardinit)
     optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)

     SEQ_LEN = 1024
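A note on the design choice in this last hunk: GeminiDDP is now constructed with strict_ddp_mode=args.shardinit. The diff itself does not explain the flag; my reading (an inference, not documented here) is that when the model was shard-initialized, Gemini must treat all ranks as a single data-parallel group so the pre-sharded parameters are re-partitioned consistently, and tying the flag to --shardinit keeps the two code paths in sync.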