ColossalAI commit a88bc828 (unverified)
Authored by ver217 on Feb 16, 2023; committed by GitHub on Feb 16, 2023
Parent: d6d6dec1

[chatgpt] disable shard init for colossalai (#2767)
Showing 1 changed file with 7 additions and 1 deletion.

applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py (+7, -1)
+import warnings
 from typing import Optional

 import torch
...
@@ -23,6 +24,7 @@ class ColossalAIStrategy(DDPStrategy):
         stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
         seed(int): The seed for the random number generator.
         shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
+            This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
         placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
             If it is 'cpu', parameters, gradients and optimizer states will be offloaded to CPU,
             If it is 'cuda', they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
...
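The docstring line added above carries the rationale for the whole commit: ZeRO-3 shard init leaves each rank holding only a slice of every parameter, while `from_pretrained()` supplies full, unsharded tensors. A minimal, purely illustrative sketch of that mismatch (toy shapes, not code from this repository):

import torch

full_weight = torch.randn(4, 8)                           # full tensor, as a checkpoint would supply
world_size, rank = 2, 0
local_shard = full_weight.chunk(world_size, dim=0)[rank]  # the slice a shard-initialized rank holds

try:
    local_shard.copy_(full_weight)   # copying a full checkpoint tensor into a shard...
except RuntimeError as err:
    print('shape mismatch:', err)    # ...fails, since the shapes no longer line up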
@@ -50,7 +52,7 @@ class ColossalAIStrategy(DDPStrategy):
                  self,
                  stage: int = 3,
                  seed: int = 42,
-                 shard_init: bool = True,  # only for stage 3
+                 shard_init: bool = False,  # only for stage 3
                  placement_policy: str = 'cuda',
                  pin_memory: bool = True,  # only for stage 3
                  force_outputs_fp32: bool = False,  # only for stage 3
...
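With the default flipped from `True` to `False`, constructing the strategy without arguments no longer requests sharded initialization. A hedged sketch of the new default, assuming the import path mirrors the file's location and that the usual distributed environment `DDPStrategy` expects is already set up:

from chatgpt.trainer.strategies import ColossalAIStrategy  # assumed import path

strategy = ColossalAIStrategy()       # stage=3, seed=42, shard_init=False after this commit
assert strategy.shard_init is False   # new default; no warning is emitted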
@@ -72,6 +74,10 @@ class ColossalAIStrategy(DDPStrategy):
         super().__init__(seed)
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
         self.stage = stage
+        # TODO(ver217): support shard_init when using from_pretrained()
+        if shard_init:
+            warnings.warn(f'Shard init is not supported yet. Ignore.')
+            shard_init = False
         self.shard_init = shard_init
         self.gemini_config = dict(device=get_current_device(),
                                   placement_policy=placement_policy,
...
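The guard added in this last hunk means a caller who still passes `shard_init=True` gets a warning and the flag is reset before it is stored. A sketch of observing that with the standard `warnings` module, under the same assumptions as above:

import warnings

from chatgpt.trainer.strategies import ColossalAIStrategy  # assumed import path

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    strategy = ColossalAIStrategy(stage=3, shard_init=True)  # explicitly request sharding

assert strategy.shard_init is False   # reset by the new guard in __init__
assert any('Shard init is not supported yet' in str(w.message) for w in caught)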