OpenDAS / ColossalAI · commit bbac6760 (unverified)

fix torch version (#3225)

Authored on Mar 23, 2023 by Fazzie-Maqianli; committed by GitHub on Mar 23, 2023.
Parent: fa97a9ca
Showing 2 changed files with 16 additions and 6 deletions (+16 / -6):

  applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py  (+15 / -5)
  applications/ChatGPT/requirements.txt  (+1 / -1)
applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py @ bbac6760
@@ -9,6 +9,10 @@ from chatgpt.models.base import Actor
 from chatgpt.models.lora import LoraLinear
 from torch.optim import Optimizer
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 import colossalai
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
 from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
@@ -143,7 +147,7 @@ class ColossalAIStrategy(DDPStrategy):
             return model.module
         return model

-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
         unwrapped_model = self._unwrap_model(model)
         # TODO : better way to get torch model from gemini model
         # to get torch model from gemini model
@@ -159,10 +163,16 @@ class ColossalAIStrategy(DDPStrategy):
                 module.merge_weights = True
                 module.eval()
         # get state_dict and save
-        state_dict = unwrapped_model.state_dict()
-        if only_rank0 and dist.get_rank() != 0:
-            return
-        torch.save(state_dict, path)
+        if not isinstance(self.model, PreTrainedModel):
+            state_dict = unwrapped_model.state_dict()
+            if only_rank0 and dist.get_rank() != 0:
+                return
+            torch.save(state_dict, path)
+        else:
+            self.model.save_pretrained(path)
+            if tokenizer is not None:
+                tokenizer.save_pretrained(path)

     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0:
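For context on how the new parameter is meant to be used: a minimal usage sketch, assuming a ColossalAIStrategy instance ("strategy") and a model it prepared ("actor"); the tokenizer choice and output path below are illustrative, not part of the commit.

    from transformers import GPT2Tokenizer

    # Illustrative only: "strategy" (a ColossalAIStrategy) and "actor" (the
    # model it set up) are assumed to come from the surrounding training code.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # With the change above, save_model branches on the model type:
    # - not a PreTrainedModel: gathers a state dict and saves it with torch.save
    # - a PreTrainedModel: calls save_pretrained(path), and also writes the
    #   tokenizer files there when a tokenizer is passed
    strategy.save_model(actor, "actor_checkpoint", only_rank0=True, tokenizer=tokenizer)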
applications/ChatGPT/requirements.txt @ bbac6760
@@ -3,5 +3,5 @@ tqdm
 datasets
 loralib
 colossalai>=0.2.4
-torch
+torch==1.12.1
 langchain
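Since the point of the commit is the torch pin, a quick sanity check (my own sketch, not part of the change) that an environment actually picked up the pinned version:

    import torch

    # requirements.txt now pins torch to 1.12.1; fail fast if the installed
    # version drifted. startswith() tolerates local suffixes like "+cu113".
    assert torch.__version__.startswith("1.12.1"), (
        "applications/ChatGPT/requirements.txt expects torch==1.12.1, "
        f"found {torch.__version__}"
    )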