Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
3aca1415
Commit
3aca1415
authored
Apr 29, 2024
by
liangjing
Browse files
Merge branch 'megatron-lm_dtk24.04' into 'main'
Megatron lm dtk24.04 See merge request
!1
parents
0024a5c6
1005e9d3
Pipeline
#1806
passed with stage
Changes
204
Pipelines
3
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
70 additions
and
20 deletions
+70
-20
tools/retro/query/retro_dataset.py
tools/retro/query/retro_dataset.py
+30
-14
tools/retro/query/utils.py
tools/retro/query/utils.py
+17
-0
tools/retro/utils.py
tools/retro/utils.py
+13
-4
tools/run_text_generation_server.py
tools/run_text_generation_server.py
+10
-2
No files found.
tools/retro/
pretraining
/retro_dataset.py
→
tools/retro/
query
/retro_dataset.py
View file @
3aca1415
...
...
@@ -5,11 +5,12 @@ import os
import
torch
from
megatron
import
get_args
,
get_retro_args
from
tools.bert_embedding.utils
import
get_index_p
ath
_m
ap
from
tools.bert_embedding.utils
import
BlockP
ath
M
ap
from
tools.retro.db.utils
import
get_merged_train_dataset
as
get_db_dataset
from
tools.retro.external_libs
import
h5py
from
.chunk_dataset
import
get_chunk_dataset_map
from
.utils
import
get_neighbor_dirname
class
RetroDataset
(
torch
.
utils
.
data
.
Dataset
):
...
...
@@ -100,7 +101,7 @@ class RetroDataset(torch.utils.data.Dataset):
return
sample
def
get_retro_datasets
():
def
get_retro_datasets
(
verify_sizes
=
True
):
'''Get train, valid, test retro datasets.'''
args
=
get_args
()
...
...
@@ -116,24 +117,39 @@ def get_retro_datasets():
chunk_dataset
=
chunk_ds_info
[
"data"
]
neighbor_dir
=
chunk_ds_info
[
"neighbor_dir"
]
neighbor_path_map
=
get_index_path_map
(
neighbor_dir
)
neighbor_path_map
=
BlockPathMap
.
from_dir
(
neighbor_dir
,
retro_args
.
retro_block_size
)
# Verify dataset prefixes.
sample_prefix
=
chunk_dataset
.
sample_dataset
.
datasets
[
0
].
index_prefix
neighbor_prefix
=
os
.
path
.
basename
(
neighbor_dir
)
assert
sample_prefix
==
neighbor_prefix
,
\
expected_dir
=
get_neighbor_dirname
(
data_key
,
chunk_dataset
.
sample_dataset
)
assert
expected_dir
==
neighbor_dir
,
\
"inconsistent dataset source; '%s' vs. '%s'."
%
\
(
sample_prefix
,
neighbor_
prefix
)
(
expected_dir
,
neighbor_
dir
)
# Verify num chunks.
n_sample_chunks
=
len
(
chunk_dataset
)
n_neighbor_chunks
=
len
(
neighbor_path_map
.
id_index_map
)
if
n_sample_chunks
!=
n_neighbor_chunks
:
print
(
"neighbor_dir : %s"
%
neighbor_dir
)
print
(
"neighbor_path_map : %s"
%
neighbor_path_map
)
raise
Exception
(
"num sampled chunks (%d) != num neighbor chunks (%d)"
%
(
n_sample_chunks
,
n_neighbor_chunks
))
n_neighbor_chunks
=
neighbor_path_map
.
max_idx
if
not
os
.
path
.
isdir
(
neighbor_dir
):
if
torch
.
distributed
.
get_rank
()
==
0
:
raise
Exception
(
"neighbor directory '%s' not found; please "
"compare --train-samples, --seq-length, --seed, "
"--eval-iters, and --eval-interval, with "
"retro preprocessing args."
%
neighbor_dir
)
torch
.
distributed
.
barrier
()
exit
()
if
verify_sizes
and
n_sample_chunks
!=
n_neighbor_chunks
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
"neighbor_dir : %s"
%
neighbor_dir
)
print
(
"neighbor_path_map : %s"
%
neighbor_path_map
)
raise
Exception
(
"num sampled chunks (%d) != num neighbor chunks "
"(%d); did you complete querying the entire "
"pretraining dataset?"
%
(
n_sample_chunks
,
n_neighbor_chunks
))
torch
.
distributed
.
barrier
()
exit
()
# Retro dataset.
retro_dataset_map
[
data_key
]
=
RetroDataset
(
...
...
tools/retro/query/utils.py
0 → 100644
View file @
3aca1415
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import
hashlib
import
os
from
megatron
import
get_retro_args
def
get_query_workdir
():
args
=
get_retro_args
()
return
os
.
path
.
join
(
args
.
retro_workdir
,
"query"
)
def
get_neighbor_dirname
(
key
,
dataset
):
hashes
=
","
.
join
([
d
.
desc_hash
for
d
in
dataset
.
datasets
])
hash
=
hashlib
.
md5
(
hashes
.
encode
()).
hexdigest
()
return
os
.
path
.
join
(
get_query_workdir
(),
os
.
path
.
basename
(
f
"
{
key
}
_
{
hash
}
"
))
tools/retro/utils.py
View file @
3aca1415
...
...
@@ -8,6 +8,7 @@ from megatron import get_retro_args
from
megatron.tokenizer.tokenizer
import
(
_BertWordPieceTokenizer
,
_GPT2BPETokenizer
,
_GPTSentencePieceTokenizer
,
)
...
...
@@ -28,10 +29,18 @@ def get_num_chunks_per_sample():
def
get_gpt_tokenizer
():
'''GPT (BPE) tokenizer.'''
args
=
get_retro_args
()
return
_GPT2BPETokenizer
(
vocab_file
=
args
.
retro_gpt_vocab_file
,
merge_file
=
args
.
retro_gpt_merge_file
,
)
tokenizer_type
=
args
.
retro_gpt_tokenizer_type
if
tokenizer_type
==
"GPT2BPETokenizer"
:
assert
args
.
retro_gpt_vocab_file
and
args
.
retro_gpt_merge_file
return
_GPT2BPETokenizer
(
vocab_file
=
args
.
retro_gpt_vocab_file
,
merge_file
=
args
.
retro_gpt_merge_file
,
)
elif
tokenizer_type
==
'GPTSentencePieceTokenizer'
:
assert
args
.
retro_gpt_tokenizer_model
is
not
None
return
_GPTSentencePieceTokenizer
(
args
.
retro_gpt_tokenizer_model
)
else
:
raise
Exception
(
"unrecognized gpt tokenizer, '%s'."
%
tokenizer_type
)
def
get_bert_tokenizer
():
...
...
tools/run_text_generation_server.py
View file @
3aca1415
...
...
@@ -13,6 +13,7 @@ from megatron.checkpointing import load_checkpoint
from
megatron.initialize
import
initialize_megatron
from
megatron.model
import
GPTModel
from
megatron.training
import
get_model
from
megatron.arguments
import
core_transformer_config_from_args
from
megatron.text_generation_server
import
MegatronServer
from
megatron.text_generation
import
generate_and_post_process
from
megatron.text_generation
import
beam_search_and_post_process
...
...
@@ -21,8 +22,10 @@ import torch
def
model_provider
(
pre_process
=
True
,
post_process
=
True
):
"""Build the model."""
config
=
core_transformer_config_from_args
(
get_args
())
print_rank_0
(
'building GPT model ...'
)
model
=
GPTModel
(
num_tokentypes
=
0
,
parallel_output
=
False
,
pre_process
=
pre_process
,
post_process
=
post_process
)
model
=
GPTModel
(
config
,
num_tokentypes
=
0
,
parallel_output
=
False
,
pre_process
=
pre_process
,
post_process
=
post_process
)
return
model
...
...
@@ -37,6 +40,8 @@ def add_text_generate_args(parser):
help
=
'Top k sampling.'
)
group
.
add_argument
(
"--out-seq-length"
,
type
=
int
,
default
=
1024
,
help
=
'Size of the output generated text.'
)
group
.
add_argument
(
"--port"
,
type
=
int
,
default
=
5000
,
help
=
'port for text generation server to run on'
)
return
parser
...
...
@@ -50,6 +55,9 @@ if __name__ == "__main__":
if
args
.
num_layers_per_virtual_pipeline_stage
is
not
None
:
print
(
"Interleaved pipeline schedule is not yet supported for text generation."
)
exit
()
print_rank_0
(
"WARNING: Forcing exit_on_missing_checkpoint to True for text "
"generation."
)
args
.
exit_on_missing_checkpoint
=
True
# Set up model and load checkpoint
model
=
get_model
(
model_provider
,
wrap_with_ddp
=
False
)
...
...
@@ -60,7 +68,7 @@ if __name__ == "__main__":
model
=
model
[
0
]
if
mpu
.
is_pipeline_first_stage
()
and
mpu
.
get_tensor_model_parallel_rank
()
==
0
:
server
=
MegatronServer
(
model
)
server
.
run
(
"0.0.0.0"
)
server
.
run
(
"0.0.0.0"
,
port
=
args
.
port
)
while
True
:
choice
=
torch
.
cuda
.
LongTensor
(
1
)
...
...
Prev
1
…
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment