wangsen / megatron-lm · Commits

Commit 7c19b3a8, authored Sep 26, 2024 by wangsen

Initial commit

Pipeline #1721 failed with stages in 0 seconds · Changes: 409 · Pipelines: 1
Showing 9 changed files with 1959 additions and 0 deletions (+1959, -0).
Changed files:
- megatron/legacy/data/__pycache__/__init__.cpython-310.pyc (+0, -0)
- megatron/legacy/data/__pycache__/data_samplers.cpython-310.pyc (+0, -0)
- megatron/legacy/data/autoaugment.py (+320, -0)
- megatron/legacy/data/biencoder_dataset_utils.py (+209, -0)
- megatron/legacy/data/data_samplers.py (+192, -0)
- megatron/legacy/data/dataset_utils.py (+726, -0)
- megatron/legacy/data/ict_dataset.py (+156, -0)
- megatron/legacy/data/image_folder.py (+302, -0)
- megatron/legacy/data/multimodal_dataset.py (+54, -0)
Too many changes to show: to preserve performance, only 409 of 409+ files are displayed.
megatron/legacy/data/__pycache__/__init__.cpython-310.pyc (new file, mode 100644)
File added.
megatron/legacy/data/__pycache__/data_samplers.cpython-310.pyc (new file, mode 100644)
File added.
megatron/legacy/data/autoaugment.py (new file, mode 100644)
"""AutoAugment data augmentation policy for ImageNet.
-- Begin license text.
MIT License
Copyright (c) 2018 Philip Popien
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-- End license text.
Code adapted from https://github.com/DeepVoltaire/AutoAugment.
This module implements the fixed AutoAugment data augmentation policy for ImageNet provided in
Appendix A, Table 9 of reference [1]. It does not include any of the search code for augmentation
policies.
Reference:
[1] https://arxiv.org/abs/1805.09501
"""
import
random
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
ImageEnhance
from
PIL
import
ImageOps
_MAX_LEVEL
=
10
# Maximum integer strength of an augmentation, if applicable.
class
ImageNetPolicy
:
"""Definition of an ImageNetPolicy.
Implements a fixed AutoAugment data augmentation policy targeted at
ImageNet training by randomly applying at runtime one of the 25 pre-defined
data augmentation sub-policies provided in Reference [1].
Usage example as a Pytorch Transform:
>>> transform=transforms.Compose([transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def
__init__
(
self
,
fillcolor
=
(
128
,
128
,
128
)):
"""Initialize an ImageNetPolicy.
Args:
fillcolor (tuple): RGB color components of the color to be used for
filling when needed (default: (128, 128, 128), which
corresponds to gray).
"""
# Instantiate a list of sub-policies.
# Each entry of the list is a SubPolicy which consists of
# two augmentation operations,
# each of those parametrized as operation, probability, magnitude.
# Those two operations are applied sequentially on the image upon call.
self
.
policies
=
[
SubPolicy
(
"posterize"
,
0.4
,
8
,
"rotate"
,
0.6
,
9
,
fillcolor
),
SubPolicy
(
"solarize"
,
0.6
,
5
,
"autocontrast"
,
0.6
,
5
,
fillcolor
),
SubPolicy
(
"equalize"
,
0.8
,
8
,
"equalize"
,
0.6
,
3
,
fillcolor
),
SubPolicy
(
"posterize"
,
0.6
,
7
,
"posterize"
,
0.6
,
6
,
fillcolor
),
SubPolicy
(
"equalize"
,
0.4
,
7
,
"solarize"
,
0.2
,
4
,
fillcolor
),
SubPolicy
(
"equalize"
,
0.4
,
4
,
"rotate"
,
0.8
,
8
,
fillcolor
),
SubPolicy
(
"solarize"
,
0.6
,
3
,
"equalize"
,
0.6
,
7
,
fillcolor
),
SubPolicy
(
"posterize"
,
0.8
,
5
,
"equalize"
,
1.0
,
2
,
fillcolor
),
SubPolicy
(
"rotate"
,
0.2
,
3
,
"solarize"
,
0.6
,
8
,
fillcolor
),
SubPolicy
(
"equalize"
,
0.6
,
8
,
"posterize"
,
0.4
,
6
,
fillcolor
),
SubPolicy
(
"rotate"
,
0.8
,
8
,
"color"
,
0.4
,
0
,
fillcolor
),
SubPolicy
(
"rotate"
,
0.4
,
9
,
"equalize"
,
0.6
,
2
,
fillcolor
),
SubPolicy
(
"equalize"
,
0.0
,
7
,
"equalize"
,
0.8
,
8
,
fillcolor
),
SubPolicy
(
"invert"
,
0.6
,
4
,
"equalize"
,
1.0
,
8
,
fillcolor
),
SubPolicy
(
"color"
,
0.6
,
4
,
"contrast"
,
1.0
,
8
,
fillcolor
),
SubPolicy
(
"rotate"
,
0.8
,
8
,
"color"
,
1.0
,
2
,
fillcolor
),
SubPolicy
(
"color"
,
0.8
,
8
,
"solarize"
,
0.8
,
7
,
fillcolor
),
SubPolicy
(
"sharpness"
,
0.4
,
7
,
"invert"
,
0.6
,
8
,
fillcolor
),
SubPolicy
(
"shearX"
,
0.6
,
5
,
"equalize"
,
1.0
,
9
,
fillcolor
),
SubPolicy
(
"color"
,
0.4
,
0
,
"equalize"
,
0.6
,
3
,
fillcolor
),
SubPolicy
(
"equalize"
,
0.4
,
7
,
"solarize"
,
0.2
,
4
,
fillcolor
),
SubPolicy
(
"solarize"
,
0.6
,
5
,
"autocontrast"
,
0.6
,
5
,
fillcolor
),
SubPolicy
(
"invert"
,
0.6
,
4
,
"equalize"
,
1.0
,
8
,
fillcolor
),
SubPolicy
(
"color"
,
0.6
,
4
,
"contrast"
,
1.0
,
8
,
fillcolor
),
SubPolicy
(
"equalize"
,
0.8
,
8
,
"equalize"
,
0.6
,
3
,
fillcolor
),
]
def
__call__
(
self
,
img
):
"""Define call method for ImageNetPolicy class."""
policy_idx
=
random
.
randint
(
0
,
len
(
self
.
policies
)
-
1
)
return
self
.
policies
[
policy_idx
](
img
)
def
__repr__
(
self
):
"""Define repr method for ImageNetPolicy class."""
return
"ImageNetPolicy"
class
SubPolicy
:
"""Definition of a SubPolicy.
A SubPolicy consists of two augmentation operations,
each of those parametrized as operation, probability, magnitude.
The two operations are applied sequentially on the image upon call.
"""
def
__init__
(
self
,
operation1
,
probability1
,
magnitude_idx1
,
operation2
,
probability2
,
magnitude_idx2
,
fillcolor
,
):
"""Initialize a SubPolicy.
Args:
operation1 (str): Key specifying the first augmentation operation.
There are fourteen key values altogether (see supported_ops below
listing supported operations). probability1 (float): Probability
within [0., 1.] of applying the first augmentation operation.
magnitude_idx1 (int): Integer specifiying the strength of the first
operation as an index further used to derive the magnitude from a
range of possible values.
operation2 (str): Key specifying the second augmentation operation.
probability2 (float): Probability within [0., 1.] of applying the
second augmentation operation.
magnitude_idx2 (int): Integer specifiying the strength of the
second operation as an index further used to derive the magnitude
from a range of possible values.
fillcolor (tuple): RGB color components of the color to be used for
filling.
Returns:
"""
# List of supported operations for operation1 and operation2.
supported_ops
=
[
"shearX"
,
"shearY"
,
"translateX"
,
"translateY"
,
"rotate"
,
"color"
,
"posterize"
,
"solarize"
,
"contrast"
,
"sharpness"
,
"brightness"
,
"autocontrast"
,
"equalize"
,
"invert"
,
]
assert
(
operation1
in
supported_ops
)
and
(
operation2
in
supported_ops
),
"SubPolicy:one of oper1 or oper2 refers to an unsupported operation."
assert
(
0.0
<=
probability1
<=
1.0
and
0.0
<=
probability2
<=
1.0
),
"SubPolicy: prob1 and prob2 should be within [0., 1.]."
assert
(
isinstance
(
magnitude_idx1
,
int
)
and
0
<=
magnitude_idx1
<=
10
),
"SubPolicy: idx1 should be specified as an integer within [0, 10]."
assert
(
isinstance
(
magnitude_idx2
,
int
)
and
0
<=
magnitude_idx2
<=
10
),
"SubPolicy: idx2 should be specified as an integer within [0, 10]."
# Define a dictionary where each key refers to a specific type of
# augmentation and the corresponding value is a range of ten possible
# magnitude values for that augmentation.
num_levels
=
_MAX_LEVEL
+
1
ranges
=
{
"shearX"
:
np
.
linspace
(
0
,
0.3
,
num_levels
),
"shearY"
:
np
.
linspace
(
0
,
0.3
,
num_levels
),
"translateX"
:
np
.
linspace
(
0
,
150
/
331
,
num_levels
),
"translateY"
:
np
.
linspace
(
0
,
150
/
331
,
num_levels
),
"rotate"
:
np
.
linspace
(
0
,
30
,
num_levels
),
"color"
:
np
.
linspace
(
0.0
,
0.9
,
num_levels
),
"posterize"
:
np
.
round
(
np
.
linspace
(
8
,
4
,
num_levels
),
0
).
astype
(
np
.
int32
),
"solarize"
:
np
.
linspace
(
256
,
0
,
num_levels
),
# range [0, 256]
"contrast"
:
np
.
linspace
(
0.0
,
0.9
,
num_levels
),
"sharpness"
:
np
.
linspace
(
0.0
,
0.9
,
num_levels
),
"brightness"
:
np
.
linspace
(
0.0
,
0.9
,
num_levels
),
"autocontrast"
:
[
0
]
*
num_levels
,
# This augmentation doesn't use magnitude parameter.
"equalize"
:
[
0
]
*
num_levels
,
# This augmentation doesn't use magnitude parameter.
"invert"
:
[
0
]
*
num_levels
,
# This augmentation doesn't use magnitude parameter.
}
def
rotate_with_fill
(
img
,
magnitude
):
"""Define rotation transformation with fill.
The input image is first rotated, then it is blended together with
a gray mask of the same size. Note that fillcolor as defined
elsewhere in this module doesn't apply here.
Args:
magnitude (float): rotation angle in degrees.
Returns:
rotated_filled (PIL Image): rotated image with gray filling for
disoccluded areas unveiled by the rotation.
"""
rotated
=
img
.
convert
(
"RGBA"
).
rotate
(
magnitude
)
rotated_filled
=
Image
.
composite
(
rotated
,
Image
.
new
(
"RGBA"
,
rotated
.
size
,
(
128
,)
*
4
),
rotated
)
return
rotated_filled
.
convert
(
img
.
mode
)
# Define a dictionary of augmentation functions where each key refers
# to a specific type of augmentation and the corresponding value defines
# the augmentation itself using a lambda function.
# pylint: disable=unnecessary-lambda
func_dict
=
{
"shearX"
:
lambda
img
,
magnitude
:
img
.
transform
(
img
.
size
,
Image
.
AFFINE
,
(
1
,
magnitude
*
random
.
choice
([
-
1
,
1
]),
0
,
0
,
1
,
0
),
Image
.
BICUBIC
,
fillcolor
=
fillcolor
,
),
"shearY"
:
lambda
img
,
magnitude
:
img
.
transform
(
img
.
size
,
Image
.
AFFINE
,
(
1
,
0
,
0
,
magnitude
*
random
.
choice
([
-
1
,
1
]),
1
,
0
),
Image
.
BICUBIC
,
fillcolor
=
fillcolor
,
),
"translateX"
:
lambda
img
,
magnitude
:
img
.
transform
(
img
.
size
,
Image
.
AFFINE
,
(
1
,
0
,
magnitude
*
img
.
size
[
0
]
*
random
.
choice
([
-
1
,
1
]),
0
,
1
,
0
,
),
fillcolor
=
fillcolor
,
),
"translateY"
:
lambda
img
,
magnitude
:
img
.
transform
(
img
.
size
,
Image
.
AFFINE
,
(
1
,
0
,
0
,
0
,
1
,
magnitude
*
img
.
size
[
1
]
*
random
.
choice
([
-
1
,
1
]),
),
fillcolor
=
fillcolor
,
),
"rotate"
:
lambda
img
,
magnitude
:
rotate_with_fill
(
img
,
magnitude
),
"color"
:
lambda
img
,
magnitude
:
ImageEnhance
.
Color
(
img
).
enhance
(
1
+
magnitude
*
random
.
choice
([
-
1
,
1
])
),
"posterize"
:
lambda
img
,
magnitude
:
ImageOps
.
posterize
(
img
,
magnitude
),
"solarize"
:
lambda
img
,
magnitude
:
ImageOps
.
solarize
(
img
,
magnitude
),
"contrast"
:
lambda
img
,
magnitude
:
ImageEnhance
.
Contrast
(
img
).
enhance
(
1
+
magnitude
*
random
.
choice
([
-
1
,
1
])),
"sharpness"
:
lambda
img
,
magnitude
:
ImageEnhance
.
Sharpness
(
img
).
enhance
(
1
+
magnitude
*
random
.
choice
([
-
1
,
1
])),
"brightness"
:
lambda
img
,
magnitude
:
ImageEnhance
.
Brightness
(
img
).
enhance
(
1
+
magnitude
*
random
.
choice
([
-
1
,
1
])),
"autocontrast"
:
lambda
img
,
magnitude
:
ImageOps
.
autocontrast
(
img
),
"equalize"
:
lambda
img
,
magnitude
:
ImageOps
.
equalize
(
img
),
"invert"
:
lambda
img
,
magnitude
:
ImageOps
.
invert
(
img
),
}
# Store probability, function and magnitude of the first augmentation
# for the sub-policy.
self
.
probability1
=
probability1
self
.
operation1
=
func_dict
[
operation1
]
self
.
magnitude1
=
ranges
[
operation1
][
magnitude_idx1
]
# Store probability, function and magnitude of the second augmentation
# for the sub-policy.
self
.
probability2
=
probability2
self
.
operation2
=
func_dict
[
operation2
]
self
.
magnitude2
=
ranges
[
operation2
][
magnitude_idx2
]
def
__call__
(
self
,
img
):
"""Define call method for SubPolicy class."""
# Randomly apply operation 1.
if
random
.
random
()
<
self
.
probability1
:
img
=
self
.
operation1
(
img
,
self
.
magnitude1
)
# Randomly apply operation 2.
if
random
.
random
()
<
self
.
probability2
:
img
=
self
.
operation2
(
img
,
self
.
magnitude2
)
return
img
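The ImageNetPolicy docstring above shows the intended use inside a torchvision transforms.Compose pipeline. As a brief illustrative sketch (not part of this commit), the policy can also be applied directly to a single PIL image; "example.jpg" is a placeholder path, and only Pillow and this module are assumed to be installed.

# Illustrative sketch only -- not part of the committed file.
from PIL import Image
from megatron.legacy.data.autoaugment import ImageNetPolicy

policy = ImageNetPolicy()            # 25 fixed sub-policies, gray fill color by default
img = Image.open("example.jpg")      # placeholder input image
augmented = policy(img)              # picks one sub-policy at random and applies it
augmented.save("example_augmented.jpg")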
megatron/legacy/data/biencoder_dataset_utils.py (new file, mode 100644)
import os
import time

import numpy as np
import torch

from megatron.training import get_args, get_tokenizer, print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.legacy.data.dataset_utils import create_masked_lm_predictions, \
    pad_and_convert_to_numpy
from megatron.legacy.data.data_samplers import MegatronPretrainingSampler


def make_attention_mask(source_block, target_block):
    """
    Returns a 2-dimensional (2-D) attention mask
    :param source_block: 1-D array
    :param target_block: 1-D array
    """
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    mask = mask.astype(np.int64)
    # (source_length, target_length)
    return mask


def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    num_workers = args.num_workers

    # Use megatron's sampler with consumed samples set to 0 as
    # this is only for evaluation and don't intend to resume half way.
    # Also, set the drop last to false as don't intend to remove
    # the last batch
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=0,
        micro_batch_size=args.micro_batch_size,
        data_parallel_rank=mpu.get_data_parallel_rank(),
        data_parallel_size=mpu.get_data_parallel_world_size(),
        drop_last=False)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_mask',
            'context_tokens', 'context_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_mask = data_b['query_mask'] < 0.5
    context_tokens = data_b['context_tokens'].long()
    context_mask = data_b['context_mask'] < 0.5
    block_indices = data_b['block_data'].long()

    return query_tokens, query_mask, \
           context_tokens, context_mask, block_indices


def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result


class BlockSampleData(object):
    """A struct for fully describing a fixed-size block of data as used in REALM

    :param start_idx: for first sentence of the block
    :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
    :param block_idx: a unique integer identifier given to every block.
    """
    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.doc_idx = doc_idx
        self.block_idx = block_idx

    def as_array(self):
        return np.array([self.start_idx, self.end_idx,
                         self.doc_idx, self.block_idx]).astype(np.int64)

    def as_tuple(self):
        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx


class BlockSamplesMapping(object):
    def __init__(self, mapping_array):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array

    def __len__(self):
        return self.mapping_array.shape[0]

    def __getitem__(self, idx):
        """Get the data associated with an indexed sample."""
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data


def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name,
                              use_one_sent_docs=False):
    """Get samples mapping for a dataset over fixed size blocks. This function also requires
    a dataset of the titles for the source documents since their lengths must be taken into account.

    :return: samples_mapping (BlockSamplesMapping)
    """

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.document_indices.dtype == np.int64
        assert block_dataset.sequence_lengths.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(name))

        from megatron.core.datasets import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.document_indices,
            block_dataset.sequence_lengths,
            title_dataset.sequence_lengths,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose,
            use_one_sent_docs)

        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.tensor([1], dtype=torch.long, device='cuda')
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
    start_time = time.time()

    mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    samples_mapping = BlockSamplesMapping(mapping_array)

    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        mapping_array.shape[0]))

    return samples_mapping
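make_attention_mask above builds a 2-D mask from two 1-D token-id arrays by marking the positions where both the source and target token ids are non-padding (>= 1). A minimal sketch of its behavior with plain NumPy arrays, assuming the megatron package is importable and a pad id of 0:

# Illustrative sketch only -- not part of the committed file.
import numpy as np
from megatron.legacy.data.biencoder_dataset_utils import make_attention_mask

source = np.array([5, 7, 0])     # last position is padding (id 0)
target = np.array([3, 0])        # second position is padding
mask = make_attention_mask(source, target)
# mask has shape (3, 2); entry [i, j] is 1 only when both source[i] and target[j] >= 1
print(mask)
# [[1 0]
#  [1 0]
#  [0 0]]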
megatron/legacy/data/data_samplers.py (new file, mode 100644)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Dataloaders."""


import random
import torch
import numpy as np
from torch.utils.data import Dataset

from megatron.training import get_args
from megatron.core import mpu


def build_pretraining_data_loader(dataset, consumed_samples):
    """Build dataloader given an input dataset."""

    if dataset is None:
        return None
    args = get_args()

    # Megatron sampler
    if args.dataloader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    elif args.dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            dataset,
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size(),
            data_sharding=args.data_sharding)
    elif args.dataloader_type == "external":
        # External dataloaders are passed through. User is expected to provide a
        # torch-compatible dataloader and define samplers, if needed.
        return dataset
    else:
        raise Exception('{} dataloader type is not supported.'.format(
            args.dataloader_type))

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True,
                                       persistent_workers=True if args.num_workers > 0 else False,
                                       )


class MegatronPretrainingSampler:

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, drop_last=True):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def get_start_end_idx(self):
        start_idx = self.data_parallel_rank * self.micro_batch_size
        end_idx = start_idx + self.micro_batch_size
        return start_idx, end_idx

    def __iter__(self):
        batch = []
        # Last batch will be dropped if drop_last is not set False
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_times_data_parallel_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

        # Check the last partial batch and see drop_last is set
        if len(batch) > 0 and not self.drop_last:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]


class RandomSeedDataset(Dataset):

    def __init__(self, dataset):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = args.seed
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        self.curr_seed = self.base_seed + epoch

    def __getitem__(self, idx):
        seed = idx + self.curr_seed
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        return self.dataset[idx]


class MegatronPretrainingRandomSampler:

    def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, data_sharding):
        # Keep a copy of input params for later use.
        self.dataset = dataset
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.data_sharding = data_sharding
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        if isinstance(self.dataset, RandomSeedDataset):
            self.dataset.set_epoch(self.epoch)

        # data sharding and random sampling
        if self.data_sharding:
            bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
                           * self.micro_batch_size
            bucket_offset = current_epoch_samples // self.data_parallel_size
            start_idx = self.data_parallel_rank * bucket_size

            g = torch.Generator()
            g.manual_seed(self.epoch)
            random_idx = torch.randperm(bucket_size, generator=g).tolist()
            idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
        else:
            full_bucket_size = (self.total_samples // self.micro_batch_size) \
                                * self.micro_batch_size
            full_bucket_offset = current_epoch_samples
            g = torch.Generator()
            g.manual_seed(self.epoch)
            idx_range_total = \
                torch.randperm(full_bucket_size, generator=g).tolist()
            idx_range_active = idx_range_total[full_bucket_offset:]
            idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]

        batch = []
        # Last batch if not complete will be dropped.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []
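MegatronPretrainingSampler above walks consecutive sample indices in global batches of micro_batch_size * data_parallel_size and yields only the slice belonging to the local data-parallel rank. A small sketch with toy numbers, assuming only that the megatron package and its torch dependency are importable (no distributed runtime is needed for the sampler itself):

# Illustrative sketch only -- not part of the committed file.
from megatron.legacy.data.data_samplers import MegatronPretrainingSampler

# 10 samples, micro batch of 2, data-parallel world of 2, this process is rank 1.
sampler = MegatronPretrainingSampler(total_samples=10, consumed_samples=0,
                                     micro_batch_size=2, data_parallel_rank=1,
                                     data_parallel_size=2)
print(list(sampler))
# Global batches are [0,1,2,3], [4,5,6,7], ...; rank 1 keeps its slice [2:4] of each,
# and the trailing partial batch [8, 9] is dropped because drop_last defaults to True:
# [[2, 3], [6, 7]]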
megatron/legacy/data/dataset_utils.py (new file, mode 100644)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Most of the code here has been copied from:
#   https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.

import math
import os
import time
import collections

import numpy as np
import torch

from megatron.training import (
    get_args,
    print_rank_0
)
from megatron.core import mpu
from megatron.core.datasets.indexed_dataset import IndexedDataset

DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPE_MULTIMODAL = 'multimodal'

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_MULTIMODAL]


def get_datasets_weights_and_num_samples(data_prefix,
                                         train_valid_test_num_samples):

    # The data prefix should be in the format of:
    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0] * num_datasets
    prefixes = [0] * num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2 * i])
        prefixes[i] = (data_prefix[2 * i + 1]).strip()
    # Normalize weights
    weight_sum = 0.0
    for weight in weights:
        weight_sum += weight
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]

    # Add 0.5% (the 1.005 factor) so in case the bleding dataset does
    # not uniformly distribute the number of samples, we still have
    # samples left to feed to the network.
    if isinstance(train_valid_test_num_samples, list):
        datasets_train_valid_test_num_samples = []
        for weight in weights:
            datasets_train_valid_test_num_samples.append(
                [int(math.ceil(val * weight * 1.005))
                 for val in train_valid_test_num_samples])
    else:
        # Used when separate dataset files are provided for train,
        # valid and test
        datasets_train_valid_test_num_samples = [
            int(math.ceil(train_valid_test_num_samples * weight * 1.005))
            for weight in weights]

    return prefixes, weights, datasets_train_valid_test_num_samples


def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments."""

    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # First part:
    # `a_end` is how many sentences go into the `A`.
    a_end = 1
    if n_sentences >= 3:
        # Note that randin in numpy is exclusive.
        a_end = np_rng.randint(1, n_sentences)
    tokens_a = []
    for j in range(a_end):
        tokens_a.extend(sample[j])

    # Second part:
    tokens_b = []
    for j in range(a_end, n_sentences):
        tokens_b.extend(sample[j])

    # Random next:
    is_next_random = False
    if np_rng.random() < 0.5:
        is_next_random = True
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random


def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    #print(len_a, len_b, max_num_tokens)
    assert len_a > 0
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
        else:
            len_b -= 1
            tokens = tokens_b
        if np_rng.random() < 0.5:
            del tokens[0]
        else:
            tokens.pop()
    return True


def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""

    tokens = []
    tokentypes = []
    # [CLS].
    tokens.append(cls_id)
    tokentypes.append(0)
    # Segment A.
    for token in tokens_a:
        tokens.append(token)
        tokentypes.append(0)
    # [SEP].
    tokens.append(sep_id)
    tokentypes.append(0)
    # Segment B.
    for token in tokens_b:
        tokens.append(token)
        tokentypes.append(1)
    if tokens_b:
        # [SEP].
        tokens.append(sep_id)
        tokentypes.append(1)

    return tokens, tokentypes


MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into
    # WordPieces, the first token does not have any marker and any subsequence
    # tokens are prefixed with ##. So whenever we see the ## token, we
    # append it to the previous set of word indexes.
    return not piece.startswith("##")


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert"):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""

    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)

    for (i, token) in enumerate(tokens):
        if token == cls_id or token == sep_id:
            token_boundary[i] = 1
            continue
        # Whole Word Masking means that if we mask all of the wordpieces
        # corresponding to an original word.
        #
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if (do_whole_word_mask and len(cand_indexes) >= 1 and
                not is_start_piece(vocab_id_to_token_dict[token])):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
            if is_start_piece(vocab_id_to_token_dict[token]):
                token_boundary[i] = 1

    output_tokens = list(tokens)

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    if not geometric_dist:
        # Note(mingdachen):
        # By default, we set the probilities to favor shorter ngram sequences.
        pvals = 1. / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)
        if favor_longer_ngram:
            pvals = pvals[::-1]

    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip current piece if they are covered in lm masking or previous ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        if not geometric_dist:
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
        else:
            # Sampling "n" from the geometric distribution and clipping it to
            # the max_ngrams. Using p=0.2 default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly looking for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)
            masked_token = None
            if masking_style == "bert":
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep original
                    if np_rng.random() < 0.5:
                        masked_token = tokens[index]
                    # 10% of the time, replace with random word
                    else:
                        masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict
    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip current piece if they are covered in lm masking or previous ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np.random.choice(ngrams[:len(cand_index_set)],
                                 p=pvals[:len(cand_index_set)] /
                                 pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of the first span
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels,
            token_boundary, masked_spans)


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Lables and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np


def build_train_valid_test_datasets_with_prefixes(train_valid_test_num_samples,
                                                  max_seq_length,
                                                  seed,
                                                  train_data_prefix=None,
                                                  valid_data_prefix=None,
                                                  test_data_prefix=None,
                                                  binary_head=False,
                                                  max_seq_length_dec=None,
                                                  dataset_type='standard_bert'):
    print_rank_0("Separate data paths provided for train, valid & test.")

    train_dataset, valid_dataset, test_dataset = None, None, None
    # Single dataset.
    if train_data_prefix is not None:
        train_dataset = build_dataset("train", train_data_prefix,
                                      train_valid_test_num_samples[0],
                                      max_seq_length, seed,
                                      binary_head, max_seq_length_dec,
                                      dataset_type=dataset_type)

    if valid_data_prefix is not None:
        valid_dataset = build_dataset("valid", valid_data_prefix,
                                      train_valid_test_num_samples[1],
                                      max_seq_length, seed, False,
                                      binary_head, max_seq_length_dec,
                                      dataset_type=dataset_type)

    if test_data_prefix is not None:
        test_dataset = build_dataset("test", test_data_prefix,
                                     train_valid_test_num_samples[2],
                                     max_seq_length, seed, False,
                                     binary_head, max_seq_length_dec,
                                     dataset_type=dataset_type)

    return (train_dataset, valid_dataset, test_dataset)


def build_train_valid_test_datasets(data_prefix, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length, seed,
                                    binary_head=False,
                                    max_seq_length_dec=None,
                                    dataset_type='standard_bert'):

    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, seed,
                                                binary_head,
                                                max_seq_length_dec,
                                                dataset_type=dataset_type)

    raise NotImplementedError("Blending currently unsupported for non-GPT dataset instances")


def _build_train_valid_test_datasets(data_prefix, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, seed,
                                     binary_head,
                                     max_seq_length_dec,
                                     dataset_type='standard_bert'):

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix, dataset_type)

    # Get start and end indices of train/valid/train into doc-idx
    # Note that doc-idx is desinged to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.document_indices.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.document_indices[splits[index]]
        end_index = indexed_dataset.document_indices[splits[index + 1]]
        print_rank_0('     sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_split_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_document_indices()
            # Slice the doc-idx
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_document_indices(doc_idx_ptr[start_index:end_index])

            dataset = build_dataset(
                name, data_prefix,
                train_valid_test_num_samples[index], max_seq_length,
                seed, binary_head, max_seq_length_dec,
                dataset_type, indexed_dataset)

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_document_indices(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.document_indices[0] == 0
            assert indexed_dataset.document_indices.shape[0] == \
                (total_num_of_documents + 1)
        return dataset

    train_dataset = build_split_dataset(0, 'train')
    valid_dataset = build_split_dataset(1, 'valid')
    test_dataset = build_split_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


def build_dataset(name, data_prefix, max_num_samples,
                  max_seq_length, seed, binary_head,
                  max_seq_length_dec, dataset_type='standard_bert',
                  indexed_dataset=None):

    from megatron.legacy.data.ict_dataset import ICTDataset
    from megatron.legacy.data.multimodal_dataset import MultiModalDataset

    if dataset_type == DSET_TYPE_BERT or dataset_type == DSET_TYPE_T5:
        raise ValueError("The Megatron-LM BERT and T5 datasets are deprecated.")

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: ", dataset_type)

    if indexed_dataset is None:
        indexed_dataset = get_indexed_dataset_(data_prefix, dataset_type)

    kwargs = dict(
        name=name,
        data_prefix=data_prefix,
        num_epochs=None,
        max_num_samples=max_num_samples,
        max_seq_length=max_seq_length,
        seed=seed,
    )

    if dataset_type == DSET_TYPE_ICT:
        args = get_args()

        title_dataset = get_indexed_dataset_(args.titles_data_path,
                                             dataset_type)

        dataset = ICTDataset(
            block_dataset=indexed_dataset,
            title_dataset=title_dataset,
            query_in_block_prob=args.query_in_block_prob,
            use_one_sent_docs=args.use_one_sent_docs,
            binary_head=binary_head,
            **kwargs
        )
    elif dataset_type == DSET_TYPE_MULTIMODAL:
        args = get_args()
        dataset = MultiModalDataset(name=name,
                                    data_prefix=data_prefix,
                                    indexed_dataset=indexed_dataset,
                                    num_samples=max_num_samples,
                                    seq_length=max_seq_length,
                                    seed=seed,
                                    img_h=args.img_h,
                                    img_w=args.img_w,
                                    )
    else:
        raise NotImplementedError("Dataset type not fully implemented.")

    return dataset


def get_indexed_dataset_(data_prefix, dataset_type):

    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    multimodal = dataset_type == DSET_TYPE_MULTIMODAL
    indexed_dataset = IndexedDataset(data_prefix, multimodal)
    assert indexed_dataset.sequence_lengths.shape[0] == indexed_dataset.document_indices[-1]
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))

    print_rank_0(' > indexed dataset stats:')
    print_rank_0('    number of documents: {}'.format(
        indexed_dataset.document_indices.shape[0] - 1))
    print_rank_0('    number of sentences: {}'.format(
        indexed_dataset.sequence_lengths.shape[0]))

    return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
    """ Get dataset splits from comma or '/' separated string list."""

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index


def get_samples_mapping(indexed_dataset,
                        data_prefix,
                        num_epochs,
                        max_num_samples,
                        max_seq_length,
                        short_seq_prob,
                        seed,
                        name,
                        binary_head):
    """Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0 and \
       not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert indexed_dataset.document_indices.dtype == np.int64
        assert indexed_dataset.sequence_lengths.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))
        # First compile and then import.
        from megatron.core.datasets import helpers
        samples_mapping = helpers.build_mapping(
            indexed_dataset.document_indices,
            indexed_dataset.sequence_lengths,
            num_epochs,
            max_num_samples,
            max_seq_length,
            short_seq_prob,
            seed,
            verbose,
            2 if binary_head else 1)
        print_rank_0(' > done building samples index maping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elasped time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
                         time.time() - start_time))
    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.tensor([1], dtype=torch.long, device='cuda')
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size() //
        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
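get_train_valid_test_split_ above turns a comma- or slash-separated weight string into document index boundaries that exactly cover the dataset. A brief sketch of the expected behavior, assuming only that the megatron package is importable; the weight string and document count are example values:

# Illustrative sketch only -- not part of the committed file.
from megatron.legacy.data.dataset_utils import get_train_valid_test_split_

# Split 1000 documents into ~94.9% / 5% / 0.1% train / valid / test.
print(get_train_valid_test_split_('949,50,1', 1000))
# [0, 949, 999, 1000]  -> train docs [0, 949), valid [949, 999), test [999, 1000)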
megatron/legacy/data/ict_dataset.py (new file, mode 100644)
import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron.training import get_tokenizer
from megatron.training import get_args
from megatron.legacy.data.dataset_utils import get_indexed_dataset_
from megatron.legacy.data.realm_dataset_utils import get_block_samples_mapping


def make_attention_mask(source_block, target_block):
    """
    Returns a 2-dimensional (2-D) attention mask
    :param source_block: 1-D array
    :param target_block: 1-D array
    """
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    mask = mask.astype(np.int64)
    # (source_length, target_length)
    return mask


def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
    rather than for training, since it is only built with a single epoch sample mapping.
    """
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=args.use_one_sent_docs
    )
    dataset = ICTDataset(**kwargs)
    return dataset


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
                 seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return len(self.samples_mapping)

    def __getitem__(self, idx):
        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
        sample_data = self.samples_mapping[idx]
        start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()

        if self.use_titles:
            title = self.title_dataset[int(doc_idx)]
            title_pad_offset = 3 + len(title)
        else:
            title = None
            title_pad_offset = 2
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1

        # randint() is inclusive for Python rng
        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query in the context query_in_block_prob fraction of the time.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)

        query_mask = make_attention_mask(query_tokens, query_tokens)
        context_mask = make_attention_mask(context_tokens, context_tokens)

        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_mask': query_mask,
            'query_pad_mask': query_pad_mask,
            'context_tokens': context_tokens,
            'context_mask': context_mask,
            'context_pad_mask': context_pad_mask,
            'block_data': block_data,
        }

        return sample

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        title = self.title_dataset[int(doc_idx)]

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def get_null_block(self):
        """Get empty block and title - used in REALM pretraining"""
        block, title = [], []
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def concat_and_pad_tokens(self, tokens, title=None):
        """Concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = list(tokens)
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            title = list(title)
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad

        return np.array(tokens), np.array(pad_mask)
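concat_and_pad_tokens above frames a context block as [CLS] title [SEP] tokens [SEP] and pads it to max_seq_length with a matching pad mask. A standalone sketch of the same framing logic; the special-token ids 101/102/0 and the length 12 are hypothetical placeholders, not values taken from this commit:

# Illustrative sketch only -- not part of the committed file.
# cls_id=101, sep_id=102, pad_id=0 and max_seq_length=12 are hypothetical values.
import numpy as np

def concat_and_pad(tokens, title, cls_id=101, sep_id=102, pad_id=0, max_seq_length=12):
    # Mirrors ICTDataset.concat_and_pad_tokens: [CLS] title [SEP] tokens [SEP] + padding.
    tokens = [cls_id] + list(title) + [sep_id] + list(tokens) + [sep_id]
    num_pad = max_seq_length - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad
    return np.array(tokens + [pad_id] * num_pad), np.array(pad_mask)

toks, mask = concat_and_pad([7, 8, 9], title=[5, 6])
print(toks)   # [101   5   6 102   7   8   9 102   0   0   0   0]
print(mask)   # [  1   1   1   1   1   1   1   1   0   0   0   0]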
megatron/legacy/data/image_folder.py (new file, mode 100644)
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py
# added support for classes_fraction and data_per_class_fraction
from torchvision.datasets import VisionDataset
from PIL import Image

import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import numpy as np


def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    return filename.lower().endswith(extensions)


def is_image_file(filename: str) -> bool:
    """Checks if a file is an allowed image extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def make_dataset(
    directory: str,
    class_to_idx: Dict[str, int],
    data_per_class_fraction: float,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    Args:
        directory (str): root dataset directory
        class_to_idx (Dict[str, int]): dictionary mapping class name to class index
        data_per_class_fraction (float): fraction of the files in each class to keep
        extensions (optional): A list of allowed extensions.
            Either extensions or is_valid_file should be passed. Defaults to None.
        is_valid_file (optional): A function that takes the path of a file
            and checks if the file is a valid file
            (used to check for corrupt files). Both extensions and
            is_valid_file should not be passed. Defaults to None.

    Raises:
        ValueError: In case ``extensions`` and ``is_valid_file`` are both None or both not None.

    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class)
    """
    instances = []
    directory = os.path.expanduser(directory)
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
    is_valid_file = cast(Callable[[str], bool], is_valid_file)
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        local_instances = []
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    local_instances.append(item)

        instances.extend(local_instances[0:int(len(local_instances) * data_per_class_fraction)])

    return instances
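# Aside (not part of the original file): a hedged illustration of how the
# data_per_class_fraction argument above behaves. The paths and mapping are
# hypothetical. With four valid files under a single class folder and a
# fraction of 0.5, only the first half of the sorted matches is kept:
#
#     samples = make_dataset('/data/train', {'cat': 0}, 0.5, extensions=IMG_EXTENSIONS)
#     # e.g. [('/data/train/cat/a.jpg', 0), ('/data/train/cat/b.jpg', 0)]
#     # out of a.jpg, b.jpg, c.jpg, d.jpg -> int(4 * 0.5) == 2 entries for this class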
class DatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/[...]/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/[...]/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            Both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes the path of a file
            and checks if the file is a valid file (used to check for corrupt files).
            Both extensions and is_valid_file should not be passed.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
            self,
            root: str,
            loader: Callable[[str], Any],
            extensions: Optional[Tuple[str, ...]] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            classes_fraction=1.0,
            data_per_class_fraction=1.0,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super(DatasetFolder, self).__init__(root, transform=transform,
                                            target_transform=target_transform)
        self.classes_fraction = classes_fraction
        self.data_per_class_fraction = data_per_class_fraction
        classes, class_to_idx = self._find_classes(self.root)
        samples = self.make_dataset(self.root,
                                    class_to_idx,
                                    self.data_per_class_fraction,
                                    extensions,
                                    is_valid_file)
        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions
        self.total = len(samples)
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        data_per_class_fraction: float,
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        return make_dataset(directory,
                            class_to_idx,
                            data_per_class_fraction,
                            extensions=extensions,
                            is_valid_file=is_valid_file)

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        all_classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes = all_classes[0:int(len(all_classes) * self.classes_fraction)]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        curr_index = index
        for x in range(self.total):
            try:
                path, target = self.samples[curr_index]
                sample = self.loader(path)
                break
            except Exception as e:
                # If the file is unreadable, retry with a random index.
                curr_index = np.random.randint(0, self.total)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes the path of an Image file
            and checks if the file is a valid file (used to check for corrupt files).

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            classes_fraction=1.0,
            data_per_class_fraction=1.0,
            loader: Callable[[str], Any] = default_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super(ImageFolder, self).__init__(root,
                                          loader,
                                          IMG_EXTENSIONS if is_valid_file is None else None,
                                          transform=transform,
                                          target_transform=target_transform,
                                          classes_fraction=classes_fraction,
                                          data_per_class_fraction=data_per_class_fraction,
                                          is_valid_file=is_valid_file)
        self.imgs = self.samples
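As a usage sketch (not part of the commit), this is roughly how the modified ImageFolder would be instantiated; the directory path, fractions, and transform below are hypothetical placeholders, shown only to illustrate the two knobs this file adds on top of the torchvision original.

from torchvision import transforms

dataset = ImageFolder(
    root='./imagenet/train',                       # hypothetical root/<class>/<image> layout
    transform=transforms.Compose([transforms.RandomResizedCrop(224),
                                  transforms.ToTensor()]),
    classes_fraction=0.5,                          # keep only half of the class folders
    data_per_class_fraction=0.1,                   # keep the first 10% of sorted files per kept class
)
image, label = dataset[0]                          # loads and transforms; __getitem__ retries with a
                                                   # random index if the file at position 0 is unreadable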
megatron/legacy/data/multimodal_dataset.py
0 → 100644
View file @
7c19b3a8
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from PIL import Image, UnidentifiedImageError
import numpy as np
import io
import torch

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

from torchvision.transforms import Compose, ToTensor, Normalize, ToPILImage, RandomResizedCrop, Resize


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(img_h, img_w):
    return Compose([
        ToPILImage(),
        RandomResizedCrop((img_h, img_w), scale=(0.5, 1.0), interpolation=BICUBIC),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


class MultiModalDataset(torch.utils.data.Dataset):

    def __init__(self, name, data_prefix, indexed_dataset,
                 num_samples, seq_length, seed, img_h, img_w):
        self.name = name
        self.indexed_dataset = indexed_dataset
        self.doc_idx = indexed_dataset.get_document_indices()
        self.visual_transform = _transform(img_h, img_w)

    def __len__(self):
        return self.indexed_dataset.sequence_lengths.shape[0]

    def __getitem__(self, idx):
        text_sample, mode = self.indexed_dataset.get(self.doc_idx[idx])
        assert mode == 0
        img_sample, mode = self.indexed_dataset.get(self.doc_idx[idx] + 1)
        assert mode == 1

        img_pad = img_sample[0].item()
        xs = img_sample[1:].tobytes(order='C')
        xs = xs[:len(xs) - img_pad]

        img_sample = np.array(Image.open(io.BytesIO(xs)))
        img_sample = self.visual_transform(img_sample).reshape(-1)

        return {'text': np.array(text_sample, dtype=np.int64),
                'img': np.array(img_sample, dtype=np.float32)}
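One last aside (not part of the commit): __getitem__ above assumes each image "document" is a flat integer array whose first element records how many trailing padding bytes to drop and whose remaining elements hold the encoded image file. The toy round-trip below illustrates that assumption with an in-memory PNG and a uint8 array standing in for the real indexed dataset; nothing here comes from the indexed dataset implementation itself.

import io
import numpy as np
from PIL import Image

# Writer side (hypothetical): encode a tiny image the way the reader expects to find it.
buf = io.BytesIO()
Image.new('RGB', (4, 4), color=(255, 0, 0)).save(buf, format='PNG')
raw = buf.getvalue()

pad = 3                                             # pretend 3 padding bytes were appended
stored = np.frombuffer(raw + b'\x00' * pad, dtype=np.uint8)
doc = np.concatenate([np.array([pad], dtype=np.uint8), stored])

# Reader side, mirroring MultiModalDataset.__getitem__:
img_pad = doc[0].item()
xs = doc[1:].tobytes(order='C')
xs = xs[:len(xs) - img_pad]
img = np.array(Image.open(io.BytesIO(xs)))          # (4, 4, 3) uint8 array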