Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
61e68783
Commit
61e68783
authored
Feb 27, 2023
by
zbian
Committed by
アマデウス
Feb 28, 2023
Browse files
fixed using zero with tp cannot access weight correctly
parent
eb5cf943
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
68 deletions
+72
-68
colossalai/nn/layer/colossalai_layer/_utils.py
colossalai/nn/layer/colossalai_layer/_utils.py
+41
-38
colossalai/nn/layer/colossalai_layer/dropout.py
colossalai/nn/layer/colossalai_layer/dropout.py
+31
-30
No files found.
colossalai/nn/layer/colossalai_layer/_utils.py
View file @
61e68783
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch
import
Tensor
from
torch
import
Tensor
from
..parallel_2d._operation
import
split_batch_2d
from
..parallel_2d._operation
import
split_batch_2d
from
..parallel_2p5d._operation
import
split_batch_2p5d
from
..parallel_2p5d._operation
import
split_batch_2p5d
from
..parallel_3d._operation
import
split_batch_3d
from
..parallel_3d._operation
import
split_batch_3d
from
..utils
import
get_tensor_parallel_mode
from
..utils
import
get_tensor_parallel_mode
_parallel_split_batch
=
{
'2d'
:
split_batch_2d
,
'2.5d'
:
split_batch_2p5d
,
'3d'
:
split_batch_3d
}
_parallel_split_batch
=
{
'2d'
:
split_batch_2d
,
'2.5d'
:
split_batch_2p5d
,
'3d'
:
split_batch_3d
}
def
partition_batch
(
input_
)
->
Tensor
:
def
partition_batch
(
input_
)
->
Tensor
:
tensor_parallel_mode
=
get_tensor_parallel_mode
()
tensor_parallel_mode
=
get_tensor_parallel_mode
()
if
tensor_parallel_mode
in
_parallel_split_batch
:
if
tensor_parallel_mode
in
_parallel_split_batch
:
if
isinstance
(
input_
,
dict
):
if
isinstance
(
input_
,
dict
):
return
{
k
:
_parallel_split_batch
[
tensor_parallel_mode
](
v
)
for
k
,
v
in
input_
.
items
()}
return
{
k
:
_parallel_split_batch
[
tensor_parallel_mode
](
v
)
for
k
,
v
in
input_
.
items
()}
else
:
else
:
return
_parallel_split_batch
[
tensor_parallel_mode
](
input_
)
return
_parallel_split_batch
[
tensor_parallel_mode
](
input_
)
else
:
else
:
return
input_
return
input_
class
ColossalaiModule
(
nn
.
Module
):
class
ColossalaiModule
(
nn
.
Module
):
def
__init__
(
self
,
module
:
nn
.
Module
,
**
kwargs
):
def
__init__
(
self
,
module
:
nn
.
Module
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
()
# copy values
self
.
module
=
module
self
.
__dict__
=
module
.
__dict__
.
copy
()
for
k
,
v
in
kwargs
.
items
():
# copy methods
setattr
(
self
,
k
,
v
)
for
name
,
attr
in
module
.
__class__
.
__dict__
.
items
():
if
name
not
in
[
'__init__'
,
'forward'
]
and
callable
(
attr
):
def
__getattr__
(
self
,
name
:
str
):
setattr
(
self
,
name
,
getattr
(
module
,
name
))
if
name
==
'module'
:
self
.
_forward_func
=
module
.
forward
return
super
().
__getattr__
(
name
)
for
k
,
v
in
kwargs
.
items
():
elif
hasattr
(
self
.
module
,
name
):
setattr
(
self
,
k
,
v
)
return
getattr
(
self
.
module
,
name
)
elif
name
in
self
.
__dict__
:
def
forward
(
self
,
*
args
):
return
self
.
__dict__
[
name
]
return
self
.
_forward_func
(
*
args
)
raise
AttributeError
(
"'{}' object has no attribute '{}'"
.
format
(
type
(
self
).
__name__
,
name
))
def
forward
(
self
,
*
args
):
return
self
.
module
(
*
args
)
colossalai/nn/layer/colossalai_layer/dropout.py
View file @
61e68783
import
torch.nn
as
nn
import
torch.nn
as
nn
from
colossalai.context
import
ParallelMode
,
seed
from
colossalai.context
import
ParallelMode
,
seed
from
..parallel_1d
import
*
from
..utils
import
get_tensor_parallel_mode
from
..parallel_1d
import
*
from
._utils
import
ColossalaiModule
from
..utils
import
get_tensor_parallel_mode
from
._utils
import
ColossalaiModule
class
Dropout
(
ColossalaiModule
):
"""Dropout layer of colossalai.
class
Dropout
(
ColossalaiModule
):
"""Dropout layer of colossalai.
Args:
p (float, optional): probability of an element to be zeroed, defaults 0.5.
Args:
inplace (bool, optional): whether to do dropout in-place, default to be False.
p (float, optional): probability of an element to be zeroed, defaults 0.5.
"""
inplace (bool, optional): whether to do dropout in-place, default to be False.
"""
def
__init__
(
self
,
p
:
float
=
0.5
,
inplace
:
bool
=
False
)
->
None
:
tensor_parallel
=
get_tensor_parallel_mode
()
def
__init__
(
self
,
p
:
float
=
0.5
,
inplace
:
bool
=
False
)
->
None
:
if
tensor_parallel
==
"1d"
:
tensor_parallel
=
get_tensor_parallel_mode
()
drop
=
Dropout1D
(
p
,
inplace
)
if
tensor_parallel
==
"1d"
:
else
:
drop
=
Dropout1D
(
p
,
inplace
)
drop
=
nn
.
Dropout
(
p
,
inplace
)
else
:
super
().
__init__
(
drop
,
tensor_parallel
=
tensor_parallel
)
drop
=
nn
.
Dropout
(
p
,
inplace
)
super
().
__init__
(
drop
,
tensor_parallel
=
tensor_parallel
)
def
forward
(
self
,
*
args
):
if
self
.
tensor_parallel
in
[
None
,
'1d'
]:
def
forward
(
self
,
*
args
):
return
self
.
_forward_func
(
*
args
)
if
self
.
tensor_parallel
in
[
None
,
'1d'
]:
else
:
return
super
().
forward
(
*
args
)
with
seed
(
ParallelMode
.
TENSOR
):
else
:
return
self
.
_forward_func
(
*
args
)
with
seed
(
ParallelMode
.
TENSOR
):
return
super
().
forward
(
*
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment