Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
1cd3650d
Commit
1cd3650d
authored
Feb 01, 2022
by
Vijay Korthikanti
Committed by
Sangkug Lym
Feb 15, 2022
Browse files
more minor fixes
parent
6f3bf9c0
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
15 additions
and
19 deletions
+15
-19
megatron/data/vit_dataset.py
megatron/data/vit_dataset.py
+2
-11
megatron/model/vision/dino.py
megatron/model/vision/dino.py
+0
-2
megatron/model/vision/esvit_swin_backbone.py
megatron/model/vision/esvit_swin_backbone.py
+4
-3
megatron/model/vision/vit_backbone.py
megatron/model/vision/vit_backbone.py
+4
-1
pretrain_vision_dino.py
pretrain_vision_dino.py
+3
-1
pretrain_vision_inpaint.py
pretrain_vision_inpaint.py
+2
-1
No files found.
megatron/data/vit_dataset.py
View file @
1cd3650d
...
...
@@ -206,9 +206,9 @@ class DinoTransform(object):
normalize
])
# transformation for the local small crops
self
.
local_crops_number
=
args
.
local_crops_number
self
.
local_crops_number
=
args
.
dino_
local_crops_number
self
.
local_transform
=
T
.
Compose
([
T
.
RandomResizedCrop
(
args
.
local_img_size
,
T
.
RandomResizedCrop
(
args
.
dino_
local_img_size
,
scale
=
(
0.05
,
scale_const
),
interpolation
=
Image
.
BICUBIC
),
flip_and_color_jitter
,
...
...
@@ -218,12 +218,6 @@ class DinoTransform(object):
def
__call__
(
self
,
image
):
crops
=
[]
args
=
get_args
()
if
args
.
street_data
:
crop_transform
=
T
.
RandomCrop
(
300
)
image
=
crop_transform
(
image
)
crops
.
append
(
self
.
global_transform1
(
image
))
crops
.
append
(
self
.
global_transform2
(
image
))
for
_
in
range
(
self
.
local_crops_number
):
...
...
@@ -247,9 +241,6 @@ def build_train_valid_datasets(data_path, image_size=224):
raise
Exception
(
'{} vit pretraining type is not supported.'
.
format
(
args
.
vit_pretraining_type
))
train_transform
=
ClassificationTransform
(
image_size
)
val_transform
=
ClassificationTransform
(
image_size
,
train
=
False
)
# training dataset
train_data_path
=
data_path
[
0
]
if
len
(
data_path
)
<=
2
else
data_path
[
2
]
train_data
=
ImageFolder
(
...
...
megatron/model/vision/dino.py
View file @
1cd3650d
...
...
@@ -15,11 +15,9 @@ from megatron import get_args, print_rank_0
from
megatron.model.utils
import
get_linear_layer
from
megatron.model.vision.vit_backbone
import
VitBackbone
from
megatron.model.module
import
MegatronModule
from
megatron.utils
import
print_tensor_min_max_norm
as
pt
from
megatron.model.vision.utils
import
trunc_normal_
from
megatron.model.vision.mit_backbone
import
mit_b5_avg
from
megatron.model.vision.esvit_swin_backbone
import
get_swin
from
megatron.model.vision.av_cam_trunk
import
get_av_cam_trunk
class
DINOLoss
(
torch
.
nn
.
Module
):
...
...
megatron/model/vision/esvit_swin_backbone.py
View file @
1cd3650d
...
...
@@ -14,7 +14,8 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
functools
import
partial
import
torch.distributed
as
dist
from
megatron.model.vision.utils
import
DropPath
,
trunc_normal_
from
megatron.model.vision.utils
import
trunc_normal_
from
megatron.model.transformer
import
DropPath
from
megatron
import
get_args
from
megatron.model
import
LayerNorm
import
numpy
as
np
...
...
@@ -809,12 +810,12 @@ class SwinTransformer(nn.Module):
def
get_swin
(
is_teacher
=
False
):
args
=
get_args
()
if
args
.
swin_type
==
"tiny"
:
if
args
.
swin_
backbone_
type
==
"tiny"
:
embed_dim
=
96
depths
=
[
2
,
2
,
6
,
2
]
num_heads
=
[
3
,
6
,
12
,
24
]
drop_path_rate
=
0.1
elif
args
.
swin_type
==
'h3'
:
elif
args
.
swin_
backbone_
type
==
'h3'
:
embed_dim
=
384
depths
=
[
2
,
2
,
18
,
2
]
num_heads
=
[
6
,
12
,
24
,
48
]
...
...
megatron/model/vision/vit_backbone.py
View file @
1cd3650d
...
...
@@ -147,7 +147,8 @@ class VitBackbone(MegatronModule):
pre_process
=
True
,
post_process
=
True
,
class_token
=
True
,
single_token_output
=
False
):
single_token_output
=
False
,
drop_path_rate
=
0.0
):
super
(
VitBackbone
,
self
).
__init__
(
share_word_embeddings
=
False
)
args
=
get_args
()
...
...
@@ -170,6 +171,7 @@ class VitBackbone(MegatronModule):
self
.
img_w
=
args
.
img_w
self
.
micro_batch_size
=
args
.
micro_batch_size
self
.
single_token_output
=
single_token_output
self
.
drop_path_rate
=
drop_path_rate
assert
self
.
img_h
%
self
.
patch_dim
==
0
assert
self
.
img_w
%
self
.
patch_dim
==
0
...
...
@@ -216,6 +218,7 @@ class VitBackbone(MegatronModule):
self
.
scaled_init_method
,
pre_process
=
self
.
pre_process
,
post_process
=
self
.
post_process
,
drop_path_rate
=
self
.
drop_path_rate
)
def
set_input_tensor
(
self
,
input_tensor
):
...
...
pretrain_vision_dino.py
View file @
1cd3650d
...
...
@@ -23,13 +23,14 @@ import torch.distributed as dist
from
functools
import
partial
from
megatron
import
get_args
,
get_timers
,
mpu
,
print_rank_0
from
megatron.data.vit_dataset
import
build_train_valid_datasets
from
megatron.model.vision.
contrastive
import
DINOPretrainModel
from
megatron.model.vision.
dino
import
DINOPretrainModel
from
megatron.model.vision.knn_monitor
import
knn_predict
from
megatron.training
import
pretrain
from
megatron.utils
import
average_losses_across_data_parallel_group
,
unwrap_model
from
torch.nn.parallel.distributed
import
DistributedDataParallel
as
torchDDP
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.model
import
Float16Module
from
megatron.model
import
ModelType
def
model_provider
(
pre_process
=
True
,
post_process
=
True
):
"""Build the model."""
...
...
@@ -116,6 +117,7 @@ if __name__ == "__main__":
pretrain
(
train_valid_test_datasets_provider
,
model_provider
,
ModelType
.
encoder_or_decoder
,
forward_step
,
args_defaults
=
{
'dataloader_type'
:
'cyclic'
}
)
...
...
pretrain_vision_inpaint.py
View file @
1cd3650d
...
...
@@ -25,7 +25,7 @@ from megatron.model.vision.inpainting import MitInpaintingModel
from
megatron.training
import
pretrain
from
megatron.utils
import
average_losses_across_data_parallel_group
from
tasks.vision.metrics
import
SSIM
,
PSNR
from
megatron.model
import
ModelType
def
model_provider
(
pre_process
=
True
,
post_process
=
True
):
"""Build the model."""
...
...
@@ -143,6 +143,7 @@ if __name__ == "__main__":
pretrain
(
train_valid_test_datasets_provider
,
model_provider
,
ModelType
.
encoder_or_decoder
,
forward_step
,
process_non_loss_data
,
args_defaults
=
{
'dataloader_type'
:
'cyclic'
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment