OpenDAS / ColossalAI · Commits

Commit f03bcb35 (unverified)
Authored Jan 04, 2022 by ver217; committed by GitHub on Jan 04, 2022
update vit example for new API (#98) (#99)
Parent: d09a79ba
Showing 3 changed files with 19 additions and 10 deletions (+19 −10):

- examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py (+5 −5)
- examples/vit_b16_imagenet_data_parallel/mixup.py (+11 −2)
- examples/vit_b16_imagenet_data_parallel/train.py (+3 −3)
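The three diffs below implement one change: labels move from positional tuples to dicts keyed by the loss function's parameter names, matching the new Colossal-AI API. A condensed before/after sketch of the label format, distilled from the diffs (the variable values here are illustrative, not commit code):

```python
import torch

label_a = torch.tensor([3])   # original target
label_b = torch.tensor([7])   # target of the sample mixed in
lam = torch.tensor([0.4])     # mixing coefficient

# Before (#98): the dataloader emitted the label as a positional tuple
# that MixupLoss.forward unpacked from *args.
old_label = (label_a, label_b, lam)

# After (#99): the label is a dict whose keys match the loss function's
# parameter names, so it can be expanded into keyword arguments.
new_label = {'targets_a': label_a, 'targets_b': label_b, 'lam': lam}
```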
examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py:
```diff
@@ -104,9 +104,9 @@ class DaliDataloader(DALIClassificationIterator):
                 img = lam * img + (1 - lam) * img[idx, :]
                 label_a, label_b = label, label[idx]
                 lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
-                label = (label_a, label_b, lam)
+                label = {'targets_a': label_a, 'targets_b': label_b, 'lam': lam}
             else:
-                label = (label, label,
-                         torch.ones(1, device=img.device, dtype=img.dtype))
-            return (img,), label
-        return (img,), (label,)
+                label = {'targets_a': label, 'targets_b': label,
+                         'lam': torch.ones(1, device=img.device, dtype=img.dtype)}
+            return img, label
+        return img, label
```
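The dict keys ('targets_a', 'targets_b', 'lam') match the parameters of the reworked MixupLoss.forward in mixup.py below, which suggests the new API forwards the label dict to the criterion as keyword arguments. A minimal sketch of that hand-off, using an inline stand-in for the loss (mixup_criterion is hypothetical, not part of the commit):

```python
import torch
import torch.nn.functional as F

def mixup_criterion(outputs, targets_a, targets_b, lam):
    # Same lam-weighted cross-entropy that MixupLoss computes in mixup.py.
    return lam * F.cross_entropy(outputs, targets_a) \
        + (1 - lam) * F.cross_entropy(outputs, targets_b)

outputs = torch.randn(4, 10)                       # dummy logits
label = {'targets_a': torch.randint(0, 10, (4,)),  # dummy dict label, shaped
         'targets_b': torch.randint(0, 10, (4,)),  # like the dataloader output
         'lam': torch.tensor([0.4])}

# Keyword expansion wires the dataloader's label straight into the loss:
loss = mixup_criterion(outputs, **label)
```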
examples/vit_b16_imagenet_data_parallel/mixup.py:
```diff
@@ -1,5 +1,7 @@
 import torch.nn as nn
 from colossalai.registry import LOSSES
+import torch
+
 
 @LOSSES.register_module
 class MixupLoss(nn.Module):
@@ -7,6 +9,13 @@ class MixupLoss(nn.Module):
         super().__init__()
         self.loss_fn = loss_fn_cls()
 
-    def forward(self, inputs, *args):
-        targets_a, targets_b, lam = args
+    def forward(self, inputs, targets_a, targets_b, lam):
         return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)
+
+
+class MixupAccuracy(nn.Module):
+    def forward(self, logits, targets):
+        targets = targets['targets_a']
+        preds = torch.argmax(logits, dim=-1)
+        correct = torch.sum(targets == preds)
+        return correct
```
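A quick check of the new MixupAccuracy contract with dummy tensors (torch only; values chosen for illustration). Accuracy is scored against targets_a alone; with the dataloader change above, non-mixup batches carry targets_a == targets_b and lam == 1, so the metric is exact there:

```python
import torch

logits = torch.tensor([[0.1, 2.0, 0.3],    # predicted class 1
                       [1.5, 0.2, 0.1]])   # predicted class 0
targets = {'targets_a': torch.tensor([1, 2]),
           'targets_b': torch.tensor([1, 2]),
           'lam': torch.ones(1)}

# Mirrors MixupAccuracy.forward: unwrap the dict, argmax, count matches.
preds = torch.argmax(logits, dim=-1)                 # tensor([1, 0])
correct = torch.sum(targets['targets_a'] == preds)   # tensor(1)
```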
examples/vit_b16_imagenet_data_parallel/train.py:
```diff
@@ -11,7 +11,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.trainer import Trainer, hooks
 from colossalai.nn.lr_scheduler import LinearWarmupLR
 from dataloader.imagenet_dali_dataloader import DaliDataloader
-from mixup import MixupLoss
+from mixup import MixupLoss, MixupAccuracy
 from timm.models import vit_base_patch16_224
 from myhooks import TotalBatchsizeHook
@@ -62,7 +62,7 @@ def main():
                              port=args.port,
                              backend=args.backend)
-    # launch from torch
+    # launch from torch
     # colossalai.launch_from_torch(config=args.config)
 
     # get logger
@@ -96,7 +96,7 @@ def main():
     # build hooks
     hook_list = [
         hooks.LossHook(),
-        hooks.AccuracyHook(accuracy_func=Accuracy()),
+        hooks.AccuracyHook(accuracy_func=MixupAccuracy()),
         hooks.LogMetricByEpochHook(logger),
         hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
         TotalBatchsizeHook(),
```
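The stock Accuracy metric compares predictions against a plain target tensor, so it cannot consume the dict labels the dataloader now yields; MixupAccuracy unwraps the dict first. For context, a sketch of the resulting hook wiring (logger and lr_scheduler are assumed to be built earlier in the example's main(), as in the unchanged parts of train.py):

```python
from colossalai.trainer import hooks
from mixup import MixupAccuracy
from myhooks import TotalBatchsizeHook

hook_list = [
    hooks.LossHook(),
    hooks.AccuracyHook(accuracy_func=MixupAccuracy()),   # dict-aware metric
    hooks.LogMetricByEpochHook(logger),                  # `logger` assumed built above
    hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),  # `lr_scheduler` assumed built above
    TotalBatchsizeHook(),
]
```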