Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c9dd0365
Unverified
Commit
c9dd0365
authored
Mar 10, 2023
by
BlueRum
Committed by
GitHub
Mar 10, 2023
Browse files
[chatgpt] fix lora save bug (#3099)
* fix colo-strategy * polish * fix lora * fix ddp * polish * polish
parent
018936a3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
2 deletions
+29
-2
applications/ChatGPT/chatgpt/models/lora.py
applications/ChatGPT/chatgpt/models/lora.py
+3
-0
applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
...ications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
+16
-0
applications/ChatGPT/chatgpt/trainer/strategies/ddp.py
applications/ChatGPT/chatgpt/trainer/strategies/ddp.py
+10
-2
No files found.
applications/ChatGPT/chatgpt/models/lora.py
View file @
c9dd0365
...
...
@@ -74,6 +74,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
# Merge the weights and mark it
if
self
.
r
>
0
:
self
.
weight
.
data
+=
T
(
self
.
lora_B
@
self
.
lora_A
)
*
self
.
scaling
delattr
(
self
,
'lora_A'
)
delattr
(
self
,
'lora_B'
)
self
.
merged
=
True
def
forward
(
self
,
x
:
torch
.
Tensor
):
...
...
@@ -125,3 +127,4 @@ class LoRAModule(nn.Module):
return
convert_to_lora_recursively
(
self
,
self
.
lora_rank
)
lora
.
mark_only_lora_as_trainable
(
self
,
self
.
lora_train_bias
)
applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
View file @
c9dd0365
...
...
@@ -6,11 +6,13 @@ import torch.distributed as dist
import
torch.nn
as
nn
import
torch.optim
as
optim
from
chatgpt.models.base
import
Actor
from
chatgpt.models.lora
import
LoraLinear
from
torch.optim
import
Optimizer
import
colossalai
from
colossalai.nn.optimizer
import
CPUAdam
,
HybridAdam
from
colossalai.nn.parallel
import
ZeroDDP
,
zero_model_wrapper
,
zero_optim_wrapper
from
colossalai.nn.parallel.utils
import
get_static_torch_model
from
colossalai.tensor
import
ProcessGroup
,
ShardSpec
from
colossalai.utils
import
get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
...
...
@@ -143,6 +145,20 @@ class ColossalAIStrategy(DDPStrategy):
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
unwrapped_model
=
self
.
_unwrap_model
(
model
)
# TODO : better way to get torch model from gemini model
# to get torch model from gemini model
if
isinstance
(
unwrapped_model
,
ZeroDDP
):
state_dict
=
unwrapped_model
.
state_dict
()
unwrapped_model
=
get_static_torch_model
(
unwrapped_model
)
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
unwrapped_model
.
load_state_dict
(
state_dict
)
# merge lora_weights into weights
for
module
in
unwrapped_model
.
modules
():
if
isinstance
(
module
,
LoraLinear
):
module
.
merge_weights
=
True
module
.
eval
()
# get state_dict and save
state_dict
=
unwrapped_model
.
state_dict
()
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
...
...
applications/ChatGPT/chatgpt/trainer/strategies/ddp.py
View file @
c9dd0365
...
...
@@ -6,6 +6,7 @@ import torch
import
torch.distributed
as
dist
import
torch.nn
as
nn
from
chatgpt.models.base
import
Actor
from
chatgpt.models.lora
import
LoraLinear
from
chatgpt.replay_buffer
import
ReplayBuffer
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.optim
import
Optimizer
...
...
@@ -72,10 +73,17 @@ class DDPStrategy(NaiveStrategy):
return
model
.
module
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
for
module
in
model
.
modules
():
if
isinstance
(
module
,
LoraLinear
):
module
.
merge_weights
=
True
module
.
eval
()
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
super
().
save_model
(
model
,
path
,
only_rank0
)
model
=
model
.
model
.
module
state_dict
=
model
.
state_dict
()
torch
.
save
(
state_dict
,
path
)
def
save_optimizer
(
self
,
optimizer
:
Optimizer
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment