Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f03bcb35
"...zh/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "85b2303b5506f9cef57bed571eedb186015a4b8c"
Unverified
Commit
f03bcb35
authored
Jan 04, 2022
by
ver217
Committed by
GitHub
Jan 04, 2022
Browse files
update vit example for new API (#98) (#99)
parent
d09a79ba
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
10 deletions
+19
-10
examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py
...enet_data_parallel/dataloader/imagenet_dali_dataloader.py
+5
-5
examples/vit_b16_imagenet_data_parallel/mixup.py
examples/vit_b16_imagenet_data_parallel/mixup.py
+11
-2
examples/vit_b16_imagenet_data_parallel/train.py
examples/vit_b16_imagenet_data_parallel/train.py
+3
-3
No files found.
examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py
View file @
f03bcb35
...
@@ -104,9 +104,9 @@ class DaliDataloader(DALIClassificationIterator):
...
@@ -104,9 +104,9 @@ class DaliDataloader(DALIClassificationIterator):
img
=
lam
*
img
+
(
1
-
lam
)
*
img
[
idx
,
:]
img
=
lam
*
img
+
(
1
-
lam
)
*
img
[
idx
,
:]
label_a
,
label_b
=
label
,
label
[
idx
]
label_a
,
label_b
=
label
,
label
[
idx
]
lam
=
torch
.
tensor
([
lam
],
device
=
img
.
device
,
dtype
=
img
.
dtype
)
lam
=
torch
.
tensor
([
lam
],
device
=
img
.
device
,
dtype
=
img
.
dtype
)
label
=
(
label_a
,
label_b
,
lam
)
label
=
{
'targets_a'
:
label_a
,
'targets_b'
:
label_b
,
'
lam
'
:
lam
}
else
:
else
:
label
=
(
label
,
label
,
torch
.
ones
(
label
=
{
'targets_a'
:
label
,
'targets_b'
:
label
,
1
,
device
=
img
.
device
,
dtype
=
img
.
dtype
)
)
'lam'
:
torch
.
ones
(
1
,
device
=
img
.
device
,
dtype
=
img
.
dtype
)
}
return
(
img
,
),
label
return
img
,
label
return
(
img
,
),
(
label
,)
return
img
,
label
examples/vit_b16_imagenet_data_parallel/mixup.py
View file @
f03bcb35
import
torch.nn
as
nn
import
torch.nn
as
nn
from
colossalai.registry
import
LOSSES
from
colossalai.registry
import
LOSSES
import
torch
@LOSSES.register_module
class MixupLoss(nn.Module):
    """Mixup training criterion.

    Wraps a base loss and evaluates it against both sets of mixed-batch
    targets, combining the two values with the mixing coefficient:
    ``lam * L(inputs, targets_a) + (1 - lam) * L(inputs, targets_b)``.

    NOTE(review): reconstructed from a token-mangled diff view; the
    ``__init__`` signature is inferred from the visible
    ``self.loss_fn = loss_fn_cls()`` assignment — confirm against the
    repository source.
    """

    def __init__(self, loss_fn_cls):
        super().__init__()
        # Instantiate the wrapped base criterion class (e.g. a
        # CrossEntropyLoss-like nn.Module).
        self.loss_fn = loss_fn_cls()

    def forward(self, inputs, targets_a, targets_b, lam):
        # Evaluate the base loss once per target set, then blend.
        loss_a = self.loss_fn(inputs, targets_a)
        loss_b = self.loss_fn(inputs, targets_b)
        return lam * loss_a + (1 - lam) * loss_b
class MixupAccuracy(nn.Module):
    """Accuracy metric for mixup-style batches.

    ``targets`` is the dict produced by the mixup dataloader
    (``'targets_a'``, ``'targets_b'``, ``'lam'``); only the primary
    ``'targets_a'`` labels are compared against the argmax predictions.
    Returns the number of correct predictions as a tensor.
    """

    def forward(self, logits, targets):
        # Accuracy is measured against the primary ("a") labels only.
        labels = targets['targets_a']
        predictions = logits.argmax(dim=-1)
        return (labels == predictions).sum()
examples/vit_b16_imagenet_data_parallel/train.py
View file @
f03bcb35
...
@@ -11,7 +11,7 @@ from colossalai.logging import get_dist_logger
...
@@ -11,7 +11,7 @@ from colossalai.logging import get_dist_logger
from
colossalai.trainer
import
Trainer
,
hooks
from
colossalai.trainer
import
Trainer
,
hooks
from
colossalai.nn.lr_scheduler
import
LinearWarmupLR
from
colossalai.nn.lr_scheduler
import
LinearWarmupLR
from
dataloader.imagenet_dali_dataloader
import
DaliDataloader
from
dataloader.imagenet_dali_dataloader
import
DaliDataloader
from
mixup
import
MixupLoss
from
mixup
import
MixupLoss
,
MixupAccuracy
from
timm.models
import
vit_base_patch16_224
from
timm.models
import
vit_base_patch16_224
from
myhooks
import
TotalBatchsizeHook
from
myhooks
import
TotalBatchsizeHook
...
@@ -96,7 +96,7 @@ def main():
...
@@ -96,7 +96,7 @@ def main():
# build hooks
# build hooks
hook_list
=
[
hook_list
=
[
hooks
.
LossHook
(),
hooks
.
LossHook
(),
hooks
.
AccuracyHook
(
accuracy_func
=
Accuracy
()),
hooks
.
AccuracyHook
(
accuracy_func
=
Mixup
Accuracy
()),
hooks
.
LogMetricByEpochHook
(
logger
),
hooks
.
LogMetricByEpochHook
(
logger
),
hooks
.
LRSchedulerHook
(
lr_scheduler
,
by_epoch
=
True
),
hooks
.
LRSchedulerHook
(
lr_scheduler
,
by_epoch
=
True
),
TotalBatchsizeHook
(),
TotalBatchsizeHook
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment