Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
3b29f0e7
Commit
3b29f0e7
authored
Jan 13, 2022
by
Vijay Korthikanti
Browse files
minor fixes
parent
7a77abd9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
2 additions
and
4 deletions
+2
-4
megatron/data/data_samplers.py
megatron/data/data_samplers.py
+1
-1
megatron/data/vit_dataset.py
megatron/data/vit_dataset.py
+1
-1
megatron/model/vision/classification.py
megatron/model/vision/classification.py
+0
-2
No files found.
megatron/data/data_samplers.py
View file @
3b29f0e7
...
@@ -164,7 +164,7 @@ class MegatronPretrainingRandomSampler:
...
@@ -164,7 +164,7 @@ class MegatronPretrainingRandomSampler:
current_epoch_samples
=
self
.
consumed_samples
%
active_total_samples
current_epoch_samples
=
self
.
consumed_samples
%
active_total_samples
assert
current_epoch_samples
%
self
.
micro_batch_times_data_parallel_size
==
0
assert
current_epoch_samples
%
self
.
micro_batch_times_data_parallel_size
==
0
if
isinstance
(
dataset
,
RandomSeedDataset
):
if
isinstance
(
self
.
dataset
,
RandomSeedDataset
):
self
.
dataset
.
set_epoch
(
self
.
epoch
)
self
.
dataset
.
set_epoch
(
self
.
epoch
)
# data sharding and random sampling
# data sharding and random sampling
...
...
megatron/data/vit_dataset.py
View file @
3b29f0e7
...
@@ -35,7 +35,7 @@ class ClassificationTransform():
...
@@ -35,7 +35,7 @@ class ClassificationTransform():
T
.
ColorJitter
(
0.4
,
0.4
,
0.4
,
0.1
),
T
.
ColorJitter
(
0.4
,
0.4
,
0.4
,
0.1
),
ImageNetPolicy
(),
ImageNetPolicy
(),
T
.
ToTensor
(),
T
.
ToTensor
(),
T
.
Normalize
(
*
self
.
mean_std
),
T
.
Normalize
(
(
0.485
,
0.456
,
0.406
),
(
0.229
,
0.224
,
0.225
)
),
T
.
ConvertImageDtype
(
self
.
data_type
)
T
.
ConvertImageDtype
(
self
.
data_type
)
])
])
else
:
else
:
...
...
megatron/model/vision/classification.py
View file @
3b29f0e7
...
@@ -19,8 +19,6 @@ import torch
...
@@ -19,8 +19,6 @@ import torch
from
megatron
import
get_args
from
megatron
import
get_args
from
megatron.model.utils
import
get_linear_layer
from
megatron.model.utils
import
get_linear_layer
from
megatron.model.vision.vit_backbone
import
VitBackbone
,
VitMlpHead
from
megatron.model.vision.vit_backbone
import
VitBackbone
,
VitMlpHead
from
megatron.model.vision.mit_backbone
import
mit_b3_avg
from
megatron.model.vision.utils
import
trunc_normal_
from
megatron.model.module
import
MegatronModule
from
megatron.model.module
import
MegatronModule
class
VitClassificationModel
(
MegatronModule
):
class
VitClassificationModel
(
MegatronModule
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment