OpenDAS / ColossalAI · Commits

Commit 78fd31f9 (unverified)
Authored Mar 24, 2023 by ver217; committed by GitHub on Mar 24, 2023
[chatgpt] add precision option for colossalai (#3233)
Parent: bd39877d

Showing 1 changed file with 12 additions and 2 deletions:

applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py (+12 −2)
@@ -30,6 +30,7 @@ class ColossalAIStrategy(DDPStrategy):
     Args:
         stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
+        precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
         seed(int): The seed for the random number generator.
         shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
             This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
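
For context, a minimal usage sketch of the new option, assuming `ColossalAIStrategy` is importable from the `chatgpt.trainer.strategies` package this file lives in (the exact re-export is an assumption, not shown in this diff):

    from chatgpt.trainer.strategies import ColossalAIStrategy

    # ZeRO stage 1/2 accept either precision.
    strategy = ColossalAIStrategy(stage=2, precision='fp32')

    # Stage 3 only supports fp16; the new default precision='fp16' matches it.
    strategy = ColossalAIStrategy(stage=3)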
@@ -59,6 +60,7 @@ class ColossalAIStrategy(DDPStrategy):
     def __init__(self,
                  stage: int = 3,
+                 precision: str = 'fp16',
                  seed: int = 42,
                  shard_init: bool = False,  # only for stage 3
                  placement_policy: str = 'cuda',
@@ -81,12 +83,17 @@ class ColossalAIStrategy(DDPStrategy):
                  norm_type: float = 2.0) -> None:
         super().__init__(seed)
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
+        assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
         self.stage = stage
         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
             warnings.warn(
                 f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()'
             )
+        if stage == 3 and precision == 'fp32':
+            warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
+            precision = 'fp16'
+        self.precision = precision
         self.shard_init = shard_init
         self.gemini_config = dict(device=get_current_device(),
                                   placement_policy=placement_policy,
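
The fp32-to-fp16 fallback added here is observable from the outside. A small check, assuming the strategy can be constructed outside a launched distributed job (the `DDPStrategy` base class may require one, which this diff does not show):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        strategy = ColossalAIStrategy(stage=3, precision='fp32')

    assert strategy.precision == 'fp16'   # coerced in __init__
    assert any('fp16' in str(w.message) for w in caught)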
@@ -127,7 +134,10 @@ class ColossalAIStrategy(DDPStrategy):
         return super().model_init_context()

     def setup_model(self, model: nn.Module) -> nn.Module:
-        return zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        if self.stage != 3 and self.precision == 'fp16':
+            model = model.half()
+        return model

     def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
         assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
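
The `self.stage != 3` guard suggests that under ZeRO-3 the `zero_model_wrapper` (Gemini) already manages parameter dtype, so the explicit cast is only needed for stages 1 and 2. As a standalone illustration of what that cast does, independent of ColossalAI:

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 16)
    assert model.weight.dtype == torch.float32

    model = model.half()   # the same cast setup_model applies when stage != 3
    assert model.weight.dtype == torch.float16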
@@ -159,7 +169,7 @@ class ColossalAIStrategy(DDPStrategy):
         # merge lora_weights into weights
         for module in unwrapped_model.modules():
             if isinstance(module, LoraLinear):
                 module.merge_weights = True
                 module.eval()
         # get state_dict and save
...
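
For reference, the merge that `LoraLinear` performs when `merge_weights` is set and `eval()` is called folds the low-rank update into the base weight, so the saved state_dict needs no separate LoRA branches. The arithmetic, sketched with made-up shapes (not the `LoraLinear` implementation itself):

    import torch

    out_dim, in_dim, rank, scaling = 32, 64, 4, 0.5
    W = torch.randn(out_dim, in_dim)     # base weight
    lora_A = torch.randn(rank, in_dim)   # low-rank factors
    lora_B = torch.randn(out_dim, rank)

    # Merged weight: W + scaling * B @ A, equivalent to always applying the LoRA branch.
    W_merged = W + scaling * (lora_B @ lora_A)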