OpenDAS / Megatron-LM · Commit 1cd3650d

Authored Feb 01, 2022 by Vijay Korthikanti
Committed Feb 15, 2022 by Sangkug Lym

    more minor fixes

Parent: 6f3bf9c0

Showing 6 changed files with 15 additions and 19 deletions (+15 -19):
    megatron/data/vit_dataset.py                  +2  -11
    megatron/model/vision/dino.py                 +0  -2
    megatron/model/vision/esvit_swin_backbone.py  +4  -3
    megatron/model/vision/vit_backbone.py         +4  -1
    pretrain_vision_dino.py                       +3  -1
    pretrain_vision_inpaint.py                    +2  -1
megatron/data/vit_dataset.py

@@ -206,9 +206,9 @@ class DinoTransform(object):
             normalize
         ])
         # transformation for the local small crops
-        self.local_crops_number = args.local_crops_number
+        self.local_crops_number = args.dino_local_crops_number
         self.local_transform = T.Compose([
-            T.RandomResizedCrop(args.local_img_size,
+            T.RandomResizedCrop(args.dino_local_img_size,
                                 scale=(0.05, scale_const),
                                 interpolation=Image.BICUBIC),
             flip_and_color_jitter,
@@ -218,12 +218,6 @@ class DinoTransform(object):
     def __call__(self, image):
         crops = []
-        args = get_args()
-        if args.street_data:
-            crop_transform = T.RandomCrop(300)
-            image = crop_transform(image)
         crops.append(self.global_transform1(image))
         crops.append(self.global_transform2(image))
         for _ in range(self.local_crops_number):
@@ -247,9 +241,6 @@ def build_train_valid_datasets(data_path, image_size=224):
         raise Exception('{} vit pretraining type is not supported.'.format(
               args.vit_pretraining_type))
-    train_transform = ClassificationTransform(image_size)
-    val_transform = ClassificationTransform(image_size, train=False)
     # training dataset
     train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
     train_data = ImageFolder(
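The two renames above (local_crops_number to dino_local_crops_number, local_img_size to dino_local_img_size) imply correspondingly renamed command-line options. As a rough sketch of how such options are registered in Megatron-style argument parsing; the option names are inferred from the diff, while the defaults and help strings are illustrative assumptions:

    # Hypothetical registration of the renamed DINO options (names inferred
    # from the diff; defaults and help text are illustrative assumptions).
    def _add_dino_args(parser):
        group = parser.add_argument_group(title='dino')
        group.add_argument('--dino-local-crops-number', type=int, default=10,
                           help='Number of small local crops per image.')
        group.add_argument('--dino-local-img-size', type=int, default=96,
                           help='Side length of each local crop in pixels.')
        return parser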
megatron/model/vision/dino.py

@@ -15,11 +15,9 @@ from megatron import get_args, print_rank_0
 from megatron.model.utils import get_linear_layer
 from megatron.model.vision.vit_backbone import VitBackbone
 from megatron.model.module import MegatronModule
-from megatron.utils import print_tensor_min_max_norm as pt
 from megatron.model.vision.utils import trunc_normal_
 from megatron.model.vision.mit_backbone import mit_b5_avg
 from megatron.model.vision.esvit_swin_backbone import get_swin
-from megatron.model.vision.av_cam_trunk import get_av_cam_trunk


 class DINOLoss(torch.nn.Module):
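This change only prunes imports: a leftover debug helper (print_tensor_min_max_norm) and the unused av_cam_trunk backbone. For readers unfamiliar with the DINOLoss class that follows these imports, the core of the DINO objective from Caron et al. (2021) looks roughly like the sketch below; this is the paper's formulation, not necessarily this repository's exact code:

    # Minimal sketch of the core DINO loss term (after Caron et al. 2021);
    # not necessarily this repository's exact implementation.
    import torch
    import torch.nn.functional as F

    def dino_loss_term(student_out, teacher_out, center,
                       student_temp=0.1, teacher_temp=0.04):
        # Teacher targets: centered and sharpened, with no gradient.
        t = F.softmax((teacher_out - center) / teacher_temp, dim=-1).detach()
        # Cross-entropy against the student's log-probabilities.
        s = F.log_softmax(student_out / student_temp, dim=-1)
        return torch.sum(-t * s, dim=-1).mean()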
megatron/model/vision/esvit_swin_backbone.py

@@ -14,7 +14,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from functools import partial
 import torch.distributed as dist
-from megatron.model.vision.utils import DropPath, trunc_normal_
+from megatron.model.vision.utils import trunc_normal_
+from megatron.model.transformer import DropPath
 from megatron import get_args
 from megatron.model import LayerNorm
 import numpy as np
@@ -809,12 +810,12 @@ class SwinTransformer(nn.Module):
 def get_swin(is_teacher=False):
     args = get_args()
-    if args.swin_type == "tiny":
+    if args.swin_backbone_type == "tiny":
         embed_dim = 96
         depths = [2, 2, 6, 2]
         num_heads = [3, 6, 12, 24]
         drop_path_rate = 0.1
-    elif args.swin_type == 'h3':
+    elif args.swin_backbone_type == 'h3':
         embed_dim = 384
         depths = [2, 2, 18, 2]
         num_heads = [6, 12, 24, 48]
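Here DropPath moves from megatron.model.vision.utils to megatron.model.transformer, and the argument gets the more descriptive name swin_backbone_type. A typical DropPath (stochastic depth) module looks like the following sketch; Megatron's own implementation may differ in detail, for example in which dimension indexes samples:

    # Typical DropPath (stochastic depth) implementation, for reference only;
    # megatron.model.transformer's version may differ in detail.
    import torch

    class DropPath(torch.nn.Module):
        """Randomly zero entire residual branches per sample during training."""
        def __init__(self, drop_prob=0.0):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0.0 or not self.training:
                return x
            keep_prob = 1.0 - self.drop_prob
            # One Bernoulli draw per sample, broadcast over remaining dims.
            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            mask = torch.bernoulli(torch.full(shape, keep_prob,
                                              dtype=x.dtype, device=x.device))
            return x / keep_prob * mask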
megatron/model/vision/vit_backbone.py

@@ -147,7 +147,8 @@ class VitBackbone(MegatronModule):
                  pre_process=True,
                  post_process=True,
                  class_token=True,
-                 single_token_output=False):
+                 single_token_output=False,
+                 drop_path_rate=0.0):
         super(VitBackbone, self).__init__(share_word_embeddings=False)
         args = get_args()
@@ -170,6 +171,7 @@ class VitBackbone(MegatronModule):
         self.img_w = args.img_w
         self.micro_batch_size = args.micro_batch_size
         self.single_token_output = single_token_output
+        self.drop_path_rate = drop_path_rate

         assert self.img_h % self.patch_dim == 0
         assert self.img_w % self.patch_dim == 0
@@ -216,6 +218,7 @@ class VitBackbone(MegatronModule):
             self.scaled_init_method,
             pre_process=self.pre_process,
             post_process=self.post_process,
+            drop_path_rate=self.drop_path_rate
         )

     def set_input_tensor(self, input_tensor):
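VitBackbone now accepts a drop_path_rate and threads it through to the transformer. A common convention, used by the original stochastic-depth and Swin code, is to ramp the per-layer drop probability linearly from 0 at the first layer to drop_path_rate at the last; whether Megatron's transformer does exactly this is an assumption, but the idea is simple:

    # Hypothetical per-layer schedule: linearly increase the drop-path
    # probability from 0 at the first layer to drop_path_rate at the last.
    import torch

    def per_layer_drop_path_rates(drop_path_rate, num_layers):
        return [r.item() for r in torch.linspace(0, drop_path_rate, num_layers)]

    # e.g. per_layer_drop_path_rates(0.1, 4) -> [0.0, 0.033..., 0.066..., 0.1]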
pretrain_vision_dino.py

@@ -23,13 +23,14 @@ import torch.distributed as dist
 from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
-from megatron.model.vision.contrastive import DINOPretrainModel
+from megatron.model.vision.dino import DINOPretrainModel
 from megatron.model.vision.knn_monitor import knn_predict
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
+from megatron.model import ModelType

 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
@@ -116,6 +117,7 @@ if __name__ == "__main__":
     pretrain(
         train_valid_test_datasets_provider,
         model_provider,
+        ModelType.encoder_or_decoder,
         forward_step,
         args_defaults={'dataloader_type': 'cyclic'}
     )
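Two fixes here: DINOPretrainModel is imported from its actual module, megatron.model.vision.dino, and pretrain() now receives ModelType.encoder_or_decoder between model_provider and forward_step, matching the model_type parameter added to Megatron's training entry point. A minimal sketch of a model_provider consistent with the corrected import; the constructor arguments shown are assumptions, not taken from this file:

    # Sketch only: the DINOPretrainModel constructor arguments shown here are
    # assumptions inferred from surrounding Megatron conventions.
    from megatron.model.vision.dino import DINOPretrainModel

    def model_provider(pre_process=True, post_process=True):
        """Build the DINO student/teacher model pair."""
        return DINOPretrainModel(pre_process=pre_process,
                                 post_process=post_process)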
pretrain_vision_inpaint.py

@@ -25,7 +25,7 @@ from megatron.model.vision.inpainting import MitInpaintingModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 from tasks.vision.metrics import SSIM, PSNR
+from megatron.model import ModelType

 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
@@ -143,6 +143,7 @@ if __name__ == "__main__":
     pretrain(
         train_valid_test_datasets_provider,
         model_provider,
+        ModelType.encoder_or_decoder,
         forward_step,
         process_non_loss_data,
         args_defaults={'dataloader_type': 'cyclic'}
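The same ModelType plumbing is applied to the inpainting script. Since this script scores reconstructions with SSIM and PSNR from tasks.vision.metrics, a standalone reminder of what PSNR computes; this sketch is illustrative and is not that module's code:

    # Illustrative PSNR between a reconstructed and a reference image,
    # assuming inputs are float tensors scaled to [0, 1].
    import torch

    def psnr(pred, target, max_val=1.0):
        mse = torch.mean((pred - target) ** 2)
        return 10.0 * torch.log10(max_val ** 2 / mse)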