OpenDAS / Megatron-LM / Commits / d3bb1a06

Commit d3bb1a06, authored Nov 28, 2020 by mohammad
added blendable dataset
parent ea81d62f
Showing 2 changed files with 139 additions and 0 deletions (+139, -0)

megatron/data/blendable_dataset.py    +75  -0
megatron/data/helpers.cpp             +64  -0
megatron/data/blendable_dataset.py  (new file, 0 → 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Blendable dataset."""

import time

import numpy as np
import torch

from megatron import print_rank_0
from megatron import mpu


class BlendableDataset(torch.utils.data.Dataset):

    def __init__(self, datasets, weights):

        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        self.size = 0
        for dataset in self.datasets:
            self.size += len(dataset)

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        weights /= sum_weights

        # Build indices.
        start_time = time.time()
        assert num_datasets < 255
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        if torch.distributed.get_rank() == 0:
            from megatron.data.dataset_utils import compile_helper
            compile_helper()
        # Simple barrier
        tmp = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(tmp, group=mpu.get_data_parallel_group())

        from megatron.data import helpers
        helpers.build_blending_indices(self.dataset_index,
                                       self.dataset_sample_index,
                                       weights, num_datasets, self.size,
                                       torch.distributed.get_rank() == 0)
        print_rank_0('> elapsed time for building blendable dataset indices: '
                     '{:.2f} (sec)'.format(time.time() - start_time))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        return self.datasets[dataset_idx][sample_idx]
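In short, BlendableDataset stitches several map-style torch datasets into a single index space whose composition follows the given weights. The sketch below shows the intended usage; it is illustrative only: it assumes torch.distributed has already been initialized (the constructor runs an all_reduce barrier and compiles the C++ helpers on rank 0), and the ListDataset wrapper is a stand-in for real Megatron datasets, not part of this commit.

import torch
from megatron.data.blendable_dataset import BlendableDataset

class ListDataset(torch.utils.data.Dataset):
    """Toy map-style dataset, used only for this example."""
    def __init__(self, samples):
        self.samples = samples
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        return self.samples[idx]

# Blend two corpora roughly 70/30; the weights are normalized internally.
web = ListDataset([{'text': f'web-{i}'} for i in range(700)])
books = ListDataset([{'text': f'book-{i}'} for i in range(300)])
blended = BlendableDataset([web, books], weights=[0.7, 0.3])

print(len(blended))  # 1000, the sum of the constituent dataset lengths
sample = blended[0]  # routed to one dataset via the precomputed index arrays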
megatron/data/helpers.cpp
@@ -33,6 +33,69 @@ using namespace std;

const int32_t LONG_SENTENCE_LEN = 512;
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
                            py::array_t<int64_t>& dataset_sample_index,
                            const py::array_t<double>& weights,
                            const int32_t num_datasets,
                            const int64_t size, const bool verbose) {
  /* Given multiple datasets and a weighting array, build samples
     such that it follows those weights.*/

  if (verbose) {
    std::cout << "> building indices for blendable datasets ..." << std::endl;
  }

  // Get the pointer access without the checks.
  auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
  auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
  auto weights_ptr = weights.unchecked<1>();

  // Initialize buffer for number of samples used for each dataset.
  int64_t current_samples[num_datasets];
  for (int64_t i = 0; i < num_datasets; ++i) {
    current_samples[i] = 0;
  }

  // For each sample:
  for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {

    // Determine where the max error in sampling is happening.
    double sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
    int64_t max_error_index = 0;
    double max_error = weights_ptr[0] * sample_idx_double -
      static_cast<double>(current_samples[0]);
    for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
      double error = weights_ptr[dataset_idx] * sample_idx_double -
        static_cast<double>(current_samples[dataset_idx]);
      if (error > max_error) {
        max_error = error;
        max_error_index = dataset_idx;
      }
    }

    // Populate the indices.
    dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
    dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];

    // Update the total samples.
    current_samples[max_error_index] += 1;

  }

  // Print info.
  if (verbose) {
    std::cout << " > sample ratios:" << std::endl;
    for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
      double ratio = static_cast<double>(current_samples[dataset_idx]) /
        static_cast<double>(size);
      std::cout << "   dataset " << dataset_idx << ", input: " <<
        weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
    }
  }

}
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
                           const py::array_t<int32_t>& doc_idx_,
                           const int32_t seq_length,
@@ -640,4 +703,5 @@ PYBIND11_MODULE(helpers, m) {
  m.def("build_mapping", &build_mapping);
  m.def("build_blocks_mapping", &build_blocks_mapping);
  m.def("build_sample_idx", &build_sample_idx);
  m.def("build_blending_indices", &build_blending_indices);
}
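The new build_blending_indices helper follows a greedy scheme: at every output position it picks the dataset whose achieved sample count lags furthest behind weight * position, so the realized mixture tracks the requested weights over every prefix of the index. A rough NumPy port, written here purely to illustrate the logic (the function below is not part of the commit):

import numpy as np

def build_blending_indices_py(weights, size):
    """Illustrative NumPy equivalent of the C++ helper's greedy loop."""
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()   # the Python caller normalizes too
    num_datasets = len(weights)
    dataset_index = np.zeros(size, dtype=np.uint8)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    current_samples = np.zeros(num_datasets, dtype=np.int64)
    for sample_idx in range(size):
        # Pick the dataset with the largest deficit against its target count.
        target = weights * max(sample_idx, 1)
        chosen = int(np.argmax(target - current_samples))
        dataset_index[sample_idx] = chosen
        dataset_sample_index[sample_idx] = current_samples[chosen]
        current_samples[chosen] += 1
    return dataset_index, dataset_sample_index

For example, with weights [0.7, 0.3] and size 10 this yields dataset indices [0, 1, 0, 0, 1, 0, 0, 1, 0, 0], a 7/3 split, while dataset_sample_index records the running position within each chosen dataset.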