Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
6f3bf9c0
Commit
6f3bf9c0
authored
Feb 01, 2022
by
Vijay Korthikanti
Committed by
Sangkug Lym
Feb 15, 2022
Browse files
minor fixes
parent
48c2a144
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
9 additions
and
9 deletions
+9
-9
megatron/arguments.py
megatron/arguments.py
+1
-3
megatron/data/vit_dataset.py
megatron/data/vit_dataset.py
+1
-1
megatron/model/vision/classification.py
megatron/model/vision/classification.py
+1
-1
megatron/model/vision/mit_backbone.py
megatron/model/vision/mit_backbone.py
+2
-1
megatron/training.py
megatron/training.py
+2
-2
pretrain_vision_classify.py
pretrain_vision_classify.py
+2
-1
No files found.
megatron/arguments.py
View file @
6f3bf9c0
...
@@ -880,10 +880,9 @@ def _add_vision_args(parser):
...
@@ -880,10 +880,9 @@ def _add_vision_args(parser):
group
.
add_argument
(
'--head-lr-mult'
,
type
=
float
,
default
=
1.0
,
group
.
add_argument
(
'--head-lr-mult'
,
type
=
float
,
default
=
1.0
,
help
=
'learning rate multiplier for head during finetuning'
)
help
=
'learning rate multiplier for head during finetuning'
)
# pretraining type and backbone selection`
# pretraining type and backbone selection`
group
.
add_argument
(
'--vision-pretraining-type'
,
type
=
str
,
default
=
'classify'
,
group
.
add_argument
(
'--vision-pretraining-type'
,
type
=
str
,
default
=
'classify'
,
choices
=
[
'classify'
,
'inpaint'
,
'
contrast
'
],
choices
=
[
'classify'
,
'inpaint'
,
'
dino
'
],
help
=
'pretraining objectives'
)
help
=
'pretraining objectives'
)
group
.
add_argument
(
'--vision-backbone-type'
,
type
=
str
,
default
=
'vit'
,
group
.
add_argument
(
'--vision-backbone-type'
,
type
=
str
,
default
=
'vit'
,
choices
=
[
'vit'
,
'mit'
,
'swin'
],
choices
=
[
'vit'
,
'mit'
,
'swin'
],
...
@@ -898,7 +897,6 @@ def _add_vision_args(parser):
...
@@ -898,7 +897,6 @@ def _add_vision_args(parser):
help
=
'mask types'
)
help
=
'mask types'
)
group
.
add_argument
(
'--mask-factor'
,
type
=
float
,
default
=
1.0
,
group
.
add_argument
(
'--mask-factor'
,
type
=
float
,
default
=
1.0
,
help
=
'mask size scaling parameter'
)
help
=
'mask size scaling parameter'
)
# dino arguments
# dino arguments
group
.
add_argument
(
'--iter-per-epoch'
,
type
=
int
,
default
=
1250
,
group
.
add_argument
(
'--iter-per-epoch'
,
type
=
int
,
default
=
1250
,
...
...
megatron/data/vit_dataset.py
View file @
6f3bf9c0
...
@@ -251,7 +251,7 @@ def build_train_valid_datasets(data_path, image_size=224):
...
@@ -251,7 +251,7 @@ def build_train_valid_datasets(data_path, image_size=224):
val_transform
=
ClassificationTransform
(
image_size
,
train
=
False
)
val_transform
=
ClassificationTransform
(
image_size
,
train
=
False
)
# training dataset
# training dataset
train_data_path
=
data_path
[
0
]
if
len
(
data_path
)
<=
2
else
data_path
[
2
]
#TODO VIJAY
train_data_path
=
data_path
[
0
]
if
len
(
data_path
)
<=
2
else
data_path
[
2
]
train_data
=
ImageFolder
(
train_data
=
ImageFolder
(
root
=
train_data_path
,
root
=
train_data_path
,
transform
=
train_transform
,
transform
=
train_transform
,
...
...
megatron/model/vision/classification.py
View file @
6f3bf9c0
...
@@ -68,7 +68,7 @@ class VitClassificationModel(MegatronModule):
...
@@ -68,7 +68,7 @@ class VitClassificationModel(MegatronModule):
class
MitClassificationModel
(
MegatronModule
):
class
MitClassificationModel
(
MegatronModule
):
"""Mix vision Transformer Model."""
"""Mix vision Transformer Model."""
def
__init__
(
self
,
num_classes
def
__init__
(
self
,
num_classes
,
pre_process
=
True
,
post_process
=
True
):
pre_process
=
True
,
post_process
=
True
):
super
(
MitClassificationModel
,
self
).
__init__
()
super
(
MitClassificationModel
,
self
).
__init__
()
args
=
get_args
()
args
=
get_args
()
...
...
megatron/model/vision/mit_backbone.py
View file @
6f3bf9c0
...
@@ -8,7 +8,8 @@ import torch
...
@@ -8,7 +8,8 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
functools
import
partial
from
functools
import
partial
from
megatron.model.vision.utils
import
DropPath
,
trunc_normal_
from
megatron.model.vision.utils
import
trunc_normal_
from
megatron.model.transformer
import
DropPath
from
megatron.model
import
LayerNorm
from
megatron.model
import
LayerNorm
...
...
megatron/training.py
View file @
6f3bf9c0
...
@@ -714,7 +714,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
...
@@ -714,7 +714,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
report_memory_flag
=
True
report_memory_flag
=
True
while
iteration
<
args
.
train_iters
:
while
iteration
<
args
.
train_iters
:
update_num_microbatches
(
args
.
consumed_train_samples
)
update_num_microbatches
(
args
.
consumed_train_samples
)
args
.
curr_iteration
=
iteration
args
.
curr_iteration
=
iteration
loss_dict
,
skipped_iter
,
grad_norm
,
num_zeros_in_grad
=
\
loss_dict
,
skipped_iter
,
grad_norm
,
num_zeros_in_grad
=
\
train_step
(
forward_step_func
,
train_step
(
forward_step_func
,
train_data_iterator
,
train_data_iterator
,
...
@@ -804,7 +804,7 @@ def evaluate(forward_step_func,
...
@@ -804,7 +804,7 @@ def evaluate(forward_step_func,
"""Evaluation."""
"""Evaluation."""
args
=
get_args
()
args
=
get_args
()
if
args
.
vision_pretraining_type
==
"
contrast
"
:
if
args
.
vision_pretraining_type
==
"
dino
"
:
args
.
knn_features
=
compute_feature_bank
(
model
)
args
.
knn_features
=
compute_feature_bank
(
model
)
# Turn on evaluation mode which disables dropout.
# Turn on evaluation mode which disables dropout.
...
...
pretrain_vision_classify.py
View file @
6f3bf9c0
...
@@ -30,14 +30,15 @@ from megatron.utils import average_losses_across_data_parallel_group
...
@@ -30,14 +30,15 @@ from megatron.utils import average_losses_across_data_parallel_group
def
model_provider
(
pre_process
=
True
,
post_process
=
True
):
def
model_provider
(
pre_process
=
True
,
post_process
=
True
):
"""Build the model."""
"""Build the model."""
print_rank_0
(
"building VIT model ..."
)
args
=
get_args
()
args
=
get_args
()
if
args
.
vision_backbone_type
==
'vit'
:
if
args
.
vision_backbone_type
==
'vit'
:
print_rank_0
(
"building VIT model ..."
)
model
=
VitClassificationModel
(
num_classes
=
args
.
num_classes
,
model
=
VitClassificationModel
(
num_classes
=
args
.
num_classes
,
pre_process
=
pre_process
,
pre_process
=
pre_process
,
post_process
=
post_process
)
post_process
=
post_process
)
elif
args
.
vision_backbone_type
==
'mit'
:
elif
args
.
vision_backbone_type
==
'mit'
:
print_rank_0
(
"building MIT model ..."
)
model
=
MitClassificationModel
(
num_classes
=
args
.
num_classes
,
model
=
MitClassificationModel
(
num_classes
=
args
.
num_classes
,
pre_process
=
pre_process
,
pre_process
=
pre_process
,
post_process
=
post_process
)
post_process
=
post_process
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment