Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
61e68783
Commit
61e68783
authored
Feb 27, 2023
by
zbian
Committed by
アマデウス
Feb 28, 2023
Browse files
fixed using zero with tp cannot access weight correctly
parent
eb5cf943
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
68 deletions
+72
-68
colossalai/nn/layer/colossalai_layer/_utils.py
colossalai/nn/layer/colossalai_layer/_utils.py
+41
-38
colossalai/nn/layer/colossalai_layer/dropout.py
colossalai/nn/layer/colossalai_layer/dropout.py
+31
-30
No files found.
colossalai/nn/layer/colossalai_layer/_utils.py
View file @
61e68783
...
@@ -24,15 +24,18 @@ class ColossalaiModule(nn.Module):
...
@@ -24,15 +24,18 @@ class ColossalaiModule(nn.Module):
def
__init__
(
self
,
module
:
nn
.
Module
,
**
kwargs
):
def
__init__
(
self
,
module
:
nn
.
Module
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
()
# copy values
self
.
module
=
module
self
.
__dict__
=
module
.
__dict__
.
copy
()
# copy methods
for
name
,
attr
in
module
.
__class__
.
__dict__
.
items
():
if
name
not
in
[
'__init__'
,
'forward'
]
and
callable
(
attr
):
setattr
(
self
,
name
,
getattr
(
module
,
name
))
self
.
_forward_func
=
module
.
forward
for
k
,
v
in
kwargs
.
items
():
for
k
,
v
in
kwargs
.
items
():
setattr
(
self
,
k
,
v
)
setattr
(
self
,
k
,
v
)
def
__getattr__
(
self
,
name
:
str
):
if
name
==
'module'
:
return
super
().
__getattr__
(
name
)
elif
hasattr
(
self
.
module
,
name
):
return
getattr
(
self
.
module
,
name
)
elif
name
in
self
.
__dict__
:
return
self
.
__dict__
[
name
]
raise
AttributeError
(
"'{}' object has no attribute '{}'"
.
format
(
type
(
self
).
__name__
,
name
))
def
forward
(
self
,
*
args
):
def
forward
(
self
,
*
args
):
return
self
.
_forward_func
(
*
args
)
return
self
.
module
(
*
args
)
colossalai/nn/layer/colossalai_layer/dropout.py
View file @
61e68783
import
torch.nn
as
nn
import
torch.nn
as
nn
from
colossalai.context
import
ParallelMode
,
seed
from
colossalai.context
import
ParallelMode
,
seed
from
..parallel_1d
import
*
from
..parallel_1d
import
*
...
@@ -24,7 +25,7 @@ class Dropout(ColossalaiModule):
...
@@ -24,7 +25,7 @@ class Dropout(ColossalaiModule):
def
forward
(
self
,
*
args
):
def
forward
(
self
,
*
args
):
if
self
.
tensor_parallel
in
[
None
,
'1d'
]:
if
self
.
tensor_parallel
in
[
None
,
'1d'
]:
return
s
elf
.
_
forward
_func
(
*
args
)
return
s
uper
().
forward
(
*
args
)
else
:
else
:
with
seed
(
ParallelMode
.
TENSOR
):
with
seed
(
ParallelMode
.
TENSOR
):
return
s
elf
.
_
forward
_func
(
*
args
)
return
s
uper
().
forward
(
*
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment