OpenDAS / Megatron-LM
Commit a7169297: Addressing review comments
Authored Jan 22, 2021 by Vijay Korthikanti
Parent: 58edb19a

Showing 9 changed files with 187 additions and 46 deletions
megatron/arguments.py              +3    -1
megatron/data/data_samplers.py     +145  -0
megatron/data/vit_dataset.py       +4    -1
megatron/model/fused_softmax.py    +2    -2
megatron/model/vit_model.py        +1    -1
megatron/optimizer/__init__.py     +4    -2
megatron/training.py               +19   -14
pretrain_vit.py                    +7    -23
tasks/vision/finetune_utils.py     +2    -2

megatron/arguments.py

@@ -362,7 +362,9 @@ def _add_training_args(parser):
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
+    group.add_argument('--dataloader_type', type=str, default='single',
+                       choices=['single', 'cyclic'],
+                       help='Single pass vs multiple pass data loader')
     return parser
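
For reference, the two training flags touched here can be exercised on their own with plain argparse. This is a standalone check that mirrors the added add_argument calls, not Megatron code:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('training')
group.add_argument('--optimizer', type=str, default='adam',
                   choices=['adam', 'sgd'], help='Optimizer function')
group.add_argument('--dataloader_type', type=str, default='single',
                   choices=['single', 'cyclic'],
                   help='Single pass vs multiple pass data loader')

print(parser.parse_args([]))                                   # adam / single
print(parser.parse_args(['--optimizer', 'sgd',
                         '--dataloader_type', 'cyclic']))      # sgd / cyclic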

megatron/data/data_loaders.py → megatron/data/data_samplers.py (renamed)

@@ -17,12 +17,12 @@
 import torch
+import random

 from megatron import get_args
 from megatron import mpu


-def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False):
+def build_pretraining_data_loader(dataset, consumed_samples):
     """Buld dataloader given an input dataset."""

     if dataset is None:
@@ -30,13 +30,23 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False
     args = get_args()

     # Megatron sampler
-    batch_sampler = MegatronPretrainingSampler(
-        total_samples=len(dataset),
-        consumed_samples=consumed_samples,
-        micro_batch_size=args.micro_batch_size,
-        data_parallel_rank=mpu.get_data_parallel_rank(),
-        data_parallel_size=mpu.get_data_parallel_world_size(),
-        random_sample=random_sample)
+    if args.dataloader_type == 'single':
+        batch_sampler = MegatronPretrainingSampler(
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=args.micro_batch_size,
+            data_parallel_rank=mpu.get_data_parallel_rank(),
+            data_parallel_size=mpu.get_data_parallel_world_size())
+    elif args.dataloader_type == 'cyclic':
+        batch_sampler = MegatronPretrainingRandomSampler(
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=args.micro_batch_size,
+            data_parallel_rank=mpu.get_data_parallel_rank(),
+            data_parallel_size=mpu.get_data_parallel_world_size())
+    else:
+        raise Exception('{} dataloader type is not supported.'.format(
+                args.dataloader_type))

     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
@@ -44,11 +54,10 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False
                                        num_workers=args.num_workers,
                                        pin_memory=True)


 class MegatronPretrainingSampler:

     def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size, random_sample=False):
+                 data_parallel_rank, data_parallel_size):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
@@ -56,14 +65,13 @@ class MegatronPretrainingSampler:
         self.data_parallel_rank = data_parallel_rank
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
-        self.random_sample = random_sample

         # Sanity checks.
         assert self.total_samples > 0, \
             'no sample to consume: {}'.format(self.total_samples)
-        # assert self.consumed_samples < self.total_samples, \
-        #     'no samples left to consume: {}, {}'.format(self.consumed_samples,
-        #                                                 self.total_samples)
+        assert self.consumed_samples < self.total_samples, \
+            'no samples left to consume: {}, {}'.format(self.consumed_samples,
+                                                        self.total_samples)
         assert self.micro_batch_size > 0
         assert data_parallel_size > 0
         assert self.data_parallel_rank < data_parallel_size, \
@@ -74,25 +82,64 @@ class MegatronPretrainingSampler:
         return self.total_samples

     def __iter__(self):
-        self.epoch = self.consumed_samples // self.total_samples
-        current_epoch_samples = self.consumed_samples % self.total_samples
-        if self.random_sample:
-            g = torch.Generator()
-            g.manual_seed(self.epoch)
-            idx_range_total = \
-                torch.randperm(self.total_samples, generator=g).tolist()
-            idx_range = idx_range_total[current_epoch_samples:]
-        else:
-            idx_range = range(current_epoch_samples, self.total_samples)
-
         batch = []
         # Last batch if not complete will be dropped.
-        for idx in idx_range:
+        for idx in range(self.consumed_samples, self.total_samples):
             batch.append(idx)
             if len(batch) == self.micro_batch_times_data_parallel_size:
-                self.consumed_samples += len(batch)
                 start_idx = self.data_parallel_rank * self.micro_batch_size
                 end_idx = start_idx + self.micro_batch_size
                 yield batch[start_idx:end_idx]
                 batch = []
+        self.consumed_samples += len(batch)
+
+
+class MegatronPretrainingRandomSampler:
+
+    def __init__(self, total_samples, consumed_samples, micro_batch_size,
+                 data_parallel_rank, data_parallel_size):
+        # Keep a copy of input params for later use.
+        self.total_samples = total_samples
+        self.consumed_samples = consumed_samples
+        self.micro_batch_size = micro_batch_size
+        self.data_parallel_rank = data_parallel_rank
+        self.data_parallel_size = data_parallel_size
+        self.micro_batch_times_data_parallel_size = \
+            self.micro_batch_size * data_parallel_size
+
+        # Sanity checks.
+        assert self.total_samples > 0, \
+            'no sample to consume: {}'.format(self.total_samples)
+        assert self.micro_batch_size > 0
+        assert data_parallel_size > 0
+        assert self.data_parallel_rank < data_parallel_size, \
+            'data_parallel_rank should be smaller than data size: {}, ' \
+            '{}'.format(self.data_parallel_rank, data_parallel_size)
+
+    def __len__(self):
+        return self.total_samples
+
+    def __iter__(self):
+        self.epoch = self.consumed_samples // self.total_samples
+        current_epoch_samples = self.consumed_samples % self.total_samples
+        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
+
+        # data sharding and random sampling
+        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
+                      * self.micro_batch_size
+        bucket_offset = current_epoch_samples // self.data_parallel_size
+        start_idx = self.data_parallel_rank * bucket_size
+
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        random_idx = torch.randperm(bucket_size, generator=g).tolist()
+        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

+        batch = []
+        # Last batch if not complete will be dropped.
+        for idx in idx_range:
+            batch.append(idx)
+            if len(batch) == self.micro_batch_size:
+                self.consumed_samples += self.micro_batch_times_data_parallel_size
+                yield batch
+                batch = []
+        self.consumed_samples += self.total_samples % self.micro_batch_times_data_parallel_size
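
To make the bucket arithmetic in the new MegatronPretrainingRandomSampler concrete, here is a small standalone sketch (not part of the commit) that reproduces the index sharding; the sample counts are made up for illustration:

import torch

def rank_indices(total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
    micro_times_dp = micro_batch_size * data_parallel_size
    epoch = consumed_samples // total_samples
    current_epoch_samples = consumed_samples % total_samples

    # Each data-parallel rank owns one contiguous "bucket" of indices.
    bucket_size = (total_samples // micro_times_dp) * micro_batch_size
    bucket_offset = current_epoch_samples // data_parallel_size
    start_idx = data_parallel_rank * bucket_size

    # The bucket is shuffled with an epoch-seeded generator, so every rank
    # draws a disjoint but reproducible permutation for the same epoch.
    g = torch.Generator()
    g.manual_seed(epoch)
    random_idx = torch.randperm(bucket_size, generator=g).tolist()
    return [start_idx + x for x in random_idx[bucket_offset:]]

# With 12 samples, micro batch size 2 and data-parallel size 2, rank 0
# shuffles indices 0..5 and rank 1 shuffles indices 6..11.
for rank in range(2):
    print(rank, rank_indices(12, 0, 2, rank, 2))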

megatron/data/vit_dataset.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import torch
 from torchvision import datasets, transforms
 from megatron.data.autoaugment import ImageNetPolicy

@@ -32,7 +33,8 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
                 brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
             )
         ]
-    process += [ImageNetPolicy(), transforms.ToTensor(), normalize]
+    fp16_t = transforms.ConvertImageDtype(torch.half)
+    process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
     transform_train = transforms.Compose(process)
     train_data = datasets.ImageFolder(
         root=train_data_path, transform=transform_train

@@ -46,6 +48,7 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
             transforms.CenterCrop(crop_size),
             transforms.ToTensor(),
             normalize,
+            fp16_t
         ]
     )
     val_data = datasets.ImageFolder(
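
The fp16 cast now lives in the dataset transform itself, which is why the explicit .half() calls disappear from get_batch and process_batch later in this commit. A minimal standalone sketch of the idea, assuming a torchvision recent enough to ship ConvertImageDtype (0.9 or newer):

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

fp16_t = transforms.ConvertImageDtype(torch.half)
pipeline = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),   # PIL image -> float32 tensor in [0, 1]
    fp16_t,                  # float32 -> float16 before the batch reaches the model
])

img = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
print(pipeline(img).dtype)   # torch.float16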

megatron/model/fused_softmax.py

@@ -122,7 +122,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert input.dim() == 4

         # invoke custom kernel
-        if self.input_in_fp16 and key_seq_len <= 2048 and \
+        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
             query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0

@@ -142,7 +142,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         if self.scale is not None:
             input = input * self.scale
-        mask_output = self.mask_func(input, mask) if mask else input
+        mask_output = self.mask_func(input, mask) if mask is not None else input
         probs = torch.nn.Softmax(dim=-1)(mask_output)

         if self.input_in_fp16 and self.softmax_in_fp32:
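
Both one-line changes in this file share a cause: truth-testing a multi-element tensor raises, so "if mask" only behaved when the mask was None or had a single element. A standalone illustration (not commit code):

import torch

mask = torch.zeros(2, 2, dtype=torch.bool)
try:
    if mask:                 # RuntimeError: Boolean value of Tensor ... is ambiguous
        pass
except RuntimeError as e:
    print(e)

if mask is not None:         # the intended check: was a mask supplied at all?
    print("mask provided, shape:", mask.shape)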

megatron/model/vit_model.py

@@ -120,7 +120,7 @@ def twod_interpolate_position_embeddings_hook(
 class VitModel(MegatronModule):
-    """Bert Language model."""
+    """Vision Transformer Model."""

     def __init__(self, num_classes, finetune=False):
         super(VitModel, self).__init__()

megatron/optimizer/__init__.py

@@ -59,12 +59,14 @@ def get_megatron_optimizer(model):
                          weight_decay=args.weight_decay,
                          betas=(args.adam_beta1, args.adam_beta2),
                          eps=args.adam_eps)
-    else:
-        assert args.optimizer == 'sgd'
+    elif args.optimizer == 'sgd':
         optimizer = SGD(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         momentum=args.sgd_momentum)
+    else:
+        raise Exception('{} optimizer is not supported.'.format(
+            args.optimizer))

     if args.fp16:
         # Constant loss scale.
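
The dispatch now fails loudly on an unknown --optimizer value instead of asserting its way into the SGD branch. A minimal sketch of the same pattern, using plain torch.optim classes as stand-ins for whatever Adam/SGD the commit's file actually imports (those imports are not shown in this diff):

import torch

def build_optimizer(name, params, lr=1e-3, weight_decay=0.0, momentum=0.9):
    if name == 'adam':
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    elif name == 'sgd':
        return torch.optim.SGD(params, lr=lr, weight_decay=weight_decay,
                               momentum=momentum)
    else:
        # unknown names now fail loudly instead of silently falling through
        raise Exception('{} optimizer is not supported.'.format(name))

model = torch.nn.Linear(4, 2)
opt = build_optimizer('sgd', model.parameters())
print(opt)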

megatron/training.py

@@ -46,7 +46,7 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.data.data_loaders import build_pretraining_data_loader
+from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import report_memory

@@ -61,8 +61,7 @@ def pretrain(train_valid_test_dataset_provider,
              model_provider,
              forward_step_func,
              extra_args_provider=None,
-             args_defaults={},
-             random_sample=False):
+             args_defaults={}):
     """Main training program.

     This function will run the followings in the order provided:

@@ -117,8 +116,7 @@ def pretrain(train_valid_test_dataset_provider,
     timers('train/valid/test data iterators').start()
     train_data_iterator, valid_data_iterator, test_data_iterator \
         = build_train_valid_test_data_iterators(
-            train_valid_test_dataset_provider, random_sample)
+            train_valid_test_dataset_provider)
     timers('train/valid/test data iterators').stop()
     print_datetime('after dataloaders are built')

@@ -955,13 +953,13 @@ def evaluate_and_print_results(prefix, forward_step_func,
     print_rank_last('-' * length)


-def cyclic_iterable(iterable):
+def cyclic_iter(iter):
     while True:
-        for x in iterable:
+        for x in iter:
             yield x


 def build_train_valid_test_data_iterators(
-        build_train_valid_test_datasets_provider, random_sample=False):
+        build_train_valid_test_datasets_provider):
     """XXX"""
     args = get_args()

@@ -1005,10 +1003,10 @@ def build_train_valid_test_data_iterators(
     # Build dataloders.
     train_dataloader = build_pretraining_data_loader(
-        train_ds, args.consumed_train_samples, random_sample)
+        train_ds, args.consumed_train_samples)
     valid_dataloader = build_pretraining_data_loader(
-        valid_ds, args.consumed_valid_samples, random_sample)
+        valid_ds, args.consumed_valid_samples)
-    test_dataloader = build_pretraining_data_loader(test_ds, 0, random_sample)
+    test_dataloader = build_pretraining_data_loader(test_ds, 0)

     # Flags to know if we need to do training/validation/testing.
     do_train = train_dataloader is not None and args.train_iters > 0

@@ -1028,19 +1026,26 @@ def build_train_valid_test_data_iterators(
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()

     # Build iterators.
+    dl_type = args.dataloader_type
+    assert dl_type in ['single', 'cyclic']
+
     if train_dataloader is not None:
-        train_data_iterator = iter(cyclic_iterable(train_dataloader))
+        train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
+                              else iter(cyclic_iter(train_dataloader))
     else:
         train_data_iterator = None

     if valid_dataloader is not None:
-        valid_data_iterator = iter(cyclic_iterable(valid_dataloader))
+        valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
+                              else iter(cyclic_iter(valid_dataloader))
     else:
         valid_data_iterator = None

     if test_dataloader is not None:
-        test_data_iterator = iter(cyclic_iterable(test_dataloader))
+        test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
+                             else iter(cyclic_iter(test_dataloader))
     else:
         test_data_iterator = None
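
The practical difference between the two dataloader_type settings is whether the iterator ever runs dry. A standalone illustration, with a plain list standing in for a torch DataLoader:

def cyclic_iter(iter):
    while True:
        for x in iter:
            yield x

loader = [10, 20, 30]                      # stand-in for a torch DataLoader

single = iter(loader)                      # dataloader_type == 'single'
cyclic = iter(cyclic_iter(loader))         # dataloader_type == 'cyclic'

print([next(single) for _ in range(3)])    # [10, 20, 30]; a 4th next() would raise StopIteration
print([next(cyclic) for _ in range(5)])    # [10, 20, 30, 10, 20]; wraps around indefinitely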

pretrain_vit.py

@@ -23,7 +23,6 @@ from megatron.model import VitModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group


 def model_provider():
     """Build the model."""

@@ -33,43 +32,28 @@ def model_provider():
     model = VitModel(num_classes=args.num_classes)
     return model

 def get_batch(data_iterator):
     """Build the batch."""
-    # Items and their type.
-    keys = ["image", "label"]
-    datatype = torch.half
-
-    # Broadcast data.
-    if data_iterator is not None:
-        data = next(data_iterator)
-    else:
-        data = None
-
-    dict_data = {}
-    dict_data["image"] = data[0].half()
-    dict_data["label"] = data[1].half()
-    data_b = mpu.broadcast_data(keys, dict_data, datatype)
-
-    # Unpack.
-    images = data_b["image"]
-    labels = data_b["label"].long()
-
+    data = next(data_iterator)
+
+    # only data parallelism; no need for broadcast
+    images = data[0].cuda()
+    labels = data[1].cuda()
+
     return images, labels


 def forward_step(data_iterator, model, input_tensor):
     """Forward step."""
     timers = get_timers()

     assert input_tensor is None

     # Get the batch.
-    timers("batch generator").start()
+    timers("batch-generator").start()
     (
         images,
         labels,
     ) = get_batch(data_iterator)
-    timers("batch generator").stop()
+    timers("batch-generator").stop()

     # Forward model. lm_labels
     logits = model(images).contiguous().float()

@@ -103,5 +87,5 @@ if __name__ == "__main__":
         train_valid_test_datasets_provider,
         model_provider,
         forward_step,
-        random_sample=True
+        args_defaults={'dataloader_type': 'cyclic'}
     )
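
pretrain_vit.py no longer passes random_sample=True; instead it seeds args_defaults so this script defaults to the cyclic loader. How Megatron actually merges those defaults lives in megatron.arguments and is not shown in this diff; the sketch below only illustrates the general defaults-override idea with a hypothetical helper:

import argparse

def parse_with_defaults(argv, defaults):
    # hypothetical helper, not Megatron's parse_args
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataloader_type', type=str, default=None,
                        choices=['single', 'cyclic'])
    args = parser.parse_args(argv)
    for key, value in defaults.items():
        if getattr(args, key) is None:   # script-level default fills unset values
            setattr(args, key, value)
    return args

print(parse_with_defaults([], {'dataloader_type': 'cyclic'}).dataloader_type)    # cyclic
print(parse_with_defaults(['--dataloader_type', 'single'], {}).dataloader_type)  # single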

tasks/vision/finetune_utils.py

@@ -33,8 +33,8 @@ from megatron.utils import average_losses_across_data_parallel_group
 def process_batch(batch):
     """Process batch and produce inputs for the model."""
-    images = batch[0].half().cuda().contiguous()
-    labels = batch[1].long().cuda().contiguous()
+    images = batch[0].cuda().contiguous()
+    labels = batch[1].cuda().contiguous()
     return images, labels