OpenDAS / ColossalAI · Commits

Commit 2cfe685b (unverified)
Authored Dec 20, 2022 by Jiarui Fang; committed by GitHub on Dec 20, 2022.

    [exmaple] add vit missing functions (#2154)

Parent: a7d95b70

Showing 2 changed files with 61 additions and 4 deletions:

    examples/images/vit/test_vit.py    +34 −2
    examples/images/vit/vit.py         +27 −2
examples/images/vit/test_vit.py  (view file @ 2cfe685b)

+import os
+import random
 from functools import partial

+import numpy as np
 import pytest
 import torch
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from utils.util import set_seed, tensor_equal, tensor_shard_equal
 from vit import get_training_components

 import colossalai
+from colossalai.context import ParallelMode
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.parallel.data_parallel import ColoDDP
-from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
+
+
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def tensor_equal(A, B):
+    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+
+
+def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
+    assert tensor.ndim == shard.ndim
+    if tensor.shape == shard.shape:
+        return tensor_equal(tensor, shard)
+    else:
+        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
+        if dims_not_eq.numel() == 1:
+            # 1D shard
+            dim = dims_not_eq.item()
+            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+        else:
+            raise

 # Only for all Linear, it's 1d_row split because Linear will be transposed when calculating.
 # But for other layers, it's 1d_col split.
 # Layernorm is not supported for now.
 ...
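The new `tensor_shard_equal` only handles shards that differ from the full tensor in exactly one dimension (a 1-D shard); any other layout falls through to the bare `raise`. Below is a minimal standalone sketch of the same chunk-and-compare logic, with `world_size` and `rank` hard-coded in place of the `gpc.get_world_size` / `gpc.get_local_rank` calls so it runs without a distributed launch:

```python
import torch

# Standalone sketch (no distributed launch): world_size and rank stand in for
# the values gpc.get_world_size / gpc.get_local_rank would return under
# ParallelMode.PARALLEL_1D.
world_size, rank = 4, 1

full = torch.randn(8, 16)                     # the unsharded parameter
shard = full.chunk(world_size, dim=1)[rank]   # this rank's slice, shape (8, 4)

# The shapes differ in exactly one dimension, so chunk the full tensor along
# that dimension and compare the rank-local piece, as tensor_shard_equal does.
dims_not_eq = torch.nonzero(torch.tensor(full.shape) != torch.tensor(shard.shape))
assert dims_not_eq.numel() == 1
dim = dims_not_eq.item()
assert torch.allclose(full.chunk(world_size, dim)[rank], shard, rtol=1e-3, atol=1e-1)
```

The loose tolerances (`rtol=1e-3`, `atol=1e-1`) mirror `tensor_equal` and leave headroom for the numerical drift that tensor-parallel execution introduces.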
examples/images/vit/vit.py  (view file @ 2cfe685b)

+from abc import ABC, abstractmethod
+
 import torch
 import torch.nn as nn
-from utils.dummy_data_generator import DummyDataGenerator
-from transformers import ViTConfig, ViTForImageClassification
 from colossalai.utils.cuda import get_current_device
+from transformers import ViTConfig, ViTForImageClassification
+
+
+class DummyDataGenerator(ABC):
+
+    def __init__(self, length=10):
+        self.length = length
+
+    @abstractmethod
+    def generate(self):
+        pass
+
+    def __iter__(self):
+        self.step = 0
+        return self
+
+    def __next__(self):
+        if self.step < self.length:
+            self.step += 1
+            return self.generate()
+        else:
+            raise StopIteration
+
+    def __len__(self):
+        return self.length


 class DummyDataLoader(DummyDataGenerator):
 ...
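`DummyDataGenerator` is a small abstract iterator: a subclass implements `generate()`, and iteration then yields `length` synthetic batches before raising `StopIteration`. A hypothetical subclass for illustration (`RandomImageLoader` is not part of the commit, and the batch shapes are made up) showing how a `DummyDataLoader`-style loader builds on it:

```python
import torch

from vit import DummyDataGenerator  # the base class inlined by this commit


class RandomImageLoader(DummyDataGenerator):
    """Illustrative only: yields random batches shaped like ViT image input."""

    def generate(self):
        data = torch.rand(4, 3, 224, 224)     # fake RGB image batch
        label = torch.randint(0, 10, (4,))    # fake class labels
        return data, label


for data, label in RandomImageLoader(length=2):
    print(data.shape, label.shape)    # torch.Size([4, 3, 224, 224]) torch.Size([4])
```

This mirrors the test_vit.py change above: both files now inline helpers that were previously imported from `utils`, making the example self-contained.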