Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c9dd0365
Unverified
Commit
c9dd0365
authored
Mar 10, 2023
by
BlueRum
Committed by
GitHub
Mar 10, 2023
Browse files
[chatgpt] fix lora save bug (#3099)
* fix colo-stratergy * polish * fix lora * fix ddp * polish * polish
parent
018936a3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
2 deletions
+29
-2
applications/ChatGPT/chatgpt/models/lora.py
applications/ChatGPT/chatgpt/models/lora.py
+3
-0
applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
...ications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
+16
-0
applications/ChatGPT/chatgpt/trainer/strategies/ddp.py
applications/ChatGPT/chatgpt/trainer/strategies/ddp.py
+10
-2
No files found.
applications/ChatGPT/chatgpt/models/lora.py
View file @
c9dd0365
...
@@ -74,6 +74,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
...
@@ -74,6 +74,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
# Merge the weights and mark it
# Merge the weights and mark it
if
self
.
r
>
0
:
if
self
.
r
>
0
:
self
.
weight
.
data
+=
T
(
self
.
lora_B
@
self
.
lora_A
)
*
self
.
scaling
self
.
weight
.
data
+=
T
(
self
.
lora_B
@
self
.
lora_A
)
*
self
.
scaling
delattr
(
self
,
'lora_A'
)
delattr
(
self
,
'lora_B'
)
self
.
merged
=
True
self
.
merged
=
True
def
forward
(
self
,
x
:
torch
.
Tensor
):
def
forward
(
self
,
x
:
torch
.
Tensor
):
...
@@ -125,3 +127,4 @@ class LoRAModule(nn.Module):
...
@@ -125,3 +127,4 @@ class LoRAModule(nn.Module):
return
return
convert_to_lora_recursively
(
self
,
self
.
lora_rank
)
convert_to_lora_recursively
(
self
,
self
.
lora_rank
)
lora
.
mark_only_lora_as_trainable
(
self
,
self
.
lora_train_bias
)
lora
.
mark_only_lora_as_trainable
(
self
,
self
.
lora_train_bias
)
applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
View file @
c9dd0365
...
@@ -6,11 +6,13 @@ import torch.distributed as dist
...
@@ -6,11 +6,13 @@ import torch.distributed as dist
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torch.optim
as
optim
from
chatgpt.models.base
import
Actor
from
chatgpt.models.base
import
Actor
from
chatgpt.models.lora
import
LoraLinear
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
import
colossalai
import
colossalai
from
colossalai.nn.optimizer
import
CPUAdam
,
HybridAdam
from
colossalai.nn.optimizer
import
CPUAdam
,
HybridAdam
from
colossalai.nn.parallel
import
ZeroDDP
,
zero_model_wrapper
,
zero_optim_wrapper
from
colossalai.nn.parallel
import
ZeroDDP
,
zero_model_wrapper
,
zero_optim_wrapper
from
colossalai.nn.parallel.utils
import
get_static_torch_model
from
colossalai.tensor
import
ProcessGroup
,
ShardSpec
from
colossalai.tensor
import
ProcessGroup
,
ShardSpec
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
...
@@ -143,6 +145,20 @@ class ColossalAIStrategy(DDPStrategy):
...
@@ -143,6 +145,20 @@ class ColossalAIStrategy(DDPStrategy):
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
unwrapped_model
=
self
.
_unwrap_model
(
model
)
unwrapped_model
=
self
.
_unwrap_model
(
model
)
# TODO : better way to get torch model from gemini model
# to get torch model from gemini model
if
isinstance
(
unwrapped_model
,
ZeroDDP
):
state_dict
=
unwrapped_model
.
state_dict
()
unwrapped_model
=
get_static_torch_model
(
unwrapped_model
)
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
unwrapped_model
.
load_state_dict
(
state_dict
)
# merge lora_weights into weights
for
module
in
unwrapped_model
.
modules
():
if
isinstance
(
module
,
LoraLinear
):
module
.
merge_weights
=
True
module
.
eval
()
# get state_dict and save
state_dict
=
unwrapped_model
.
state_dict
()
state_dict
=
unwrapped_model
.
state_dict
()
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
return
...
...
applications/ChatGPT/chatgpt/trainer/strategies/ddp.py
View file @
c9dd0365
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.nn
as
nn
import
torch.nn
as
nn
from
chatgpt.models.base
import
Actor
from
chatgpt.models.base
import
Actor
from
chatgpt.models.lora
import
LoraLinear
from
chatgpt.replay_buffer
import
ReplayBuffer
from
chatgpt.replay_buffer
import
ReplayBuffer
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
...
@@ -72,10 +73,17 @@ class DDPStrategy(NaiveStrategy):
...
@@ -72,10 +73,17 @@ class DDPStrategy(NaiveStrategy):
return
model
.
module
return
model
.
module
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
for
module
in
model
.
modules
():
if
isinstance
(
module
,
LoraLinear
):
module
.
merge_weights
=
True
module
.
eval
()
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
return
super
().
save_model
(
model
,
path
,
only_rank0
)
model
=
model
.
model
.
module
state_dict
=
model
.
state_dict
()
torch
.
save
(
state_dict
,
path
)
def
save_optimizer
(
self
,
optimizer
:
Optimizer
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
def
save_optimizer
(
self
,
optimizer
:
Optimizer
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment