Commit 3aca1415 authored by liangjing

Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04

See merge request !1
parents 0024a5c6 1005e9d3
Pipeline #1806 passed
import torch
from tests.test_utilities import Utils
from tests.unit_tests.test_utilities import Utils
from megatron.core import ModelParallelConfig
import megatron.core.pipeline_parallel.schedules as schedule
from pytest_mock import mocker
import pytest
@@ -20,8 +21,8 @@ def test_get_forward_backward_func():
def test_deallocate_output_tensor():
out = torch.tensor([[1, 2, 3], [4, 5, 6]])
schedule.deallocate_output_tensor(out)
assert(out.nelement() == 1)
assert(out.nelement() == 6)
"""
def test_forward_backward_func_without_pipeline_parallel(mocker):
from megatron.core.pipeline_parallel import get_forward_backward_func
@@ -45,12 +46,18 @@ def test_forward_backward_func_without_pipeline_parallel(mocker):
assert(schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining)
mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2)
config = ModelParallelConfig(
pipeline_model_parallel_size = 1
)
model.config = config
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=None,
model=[model],
num_microbatches=4,
seq_length=None,
micro_batch_size=None,
forward_only=False)
loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
@@ -83,6 +90,12 @@ def test_forward_backward_func_with_pipeline_parallel(mocker):
sequence_length = 512
micro_batch_size = 8
hidden_size = 256
config = ModelParallelConfig(
pipeline_model_parallel_size = 4,
sequence_parallel = False
)
model.config = config
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
@@ -90,9 +103,8 @@ def test_forward_backward_func_with_pipeline_parallel(mocker):
dtype=torch.float32,
model=[model],
num_microbatches= micro_batch_size,
tensor_shape=[sequence_length, micro_batch_size, hidden_size],
decoder_seq_length=sequence_length,
sequence_parallel=False,
seq_length=sequence_length,
micro_batch_size=micro_batch_size,
forward_only=True)
loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
@@ -101,7 +113,7 @@ def test_forward_backward_func_with_pipeline_parallel(mocker):
assert(i['loss_reduced'] == j['loss_reduced'])
Utils.destroy_model_parallel()
"""
def test_forward_backward_func_with_interleaving(mocker):
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.enums import ModelType
@@ -186,4 +198,4 @@ def test_forward_backward_func_with_interleaving(mocker):
assert(i['loss_reduced'] == j['loss_reduced'])
Utils.destroy_model_parallel()
"""
\ No newline at end of file
"""
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import pytest
import torch
from megatron.core.transformer.attention import SelfAttention
from tests.unit_tests.test_utilities import Utils
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
class TestParallelAttention:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
self.transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.parallel_attention = SelfAttention(self.transformer_config)
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_constructor(self):
assert isinstance(self.parallel_attention, SelfAttention)
assert self.parallel_attention.layer_number == 1
num_weights = sum([p.numel() for p in self.parallel_attention.parameters()])
assert num_weights == 648
def test_cpu_forward(self):
# we can't currently do this because the global memory buffer is on GPU
pass
def test_gpu_forward(self):
config = self.parallel_attention.config
sequence_length = 32
micro_batch_size = 2
self.parallel_attention.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size))
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
output, bias = self.parallel_attention(hidden_states, attention_mask)
assert config.recompute_granularity is None
assert output.shape[0] == sequence_length
assert output.shape[1] == micro_batch_size
assert output.shape[2] == config.hidden_size
assert bias.shape[0] == config.hidden_size
def test_checkpointed_gpu_forward(self):
transformer_config = self.transformer_config
transformer_config.recompute_granularity='selective'
checkpointed_parallel_attention = SelfAttention(transformer_config)
config = checkpointed_parallel_attention.config
sequence_length = 32
micro_batch_size = 2
checkpointed_parallel_attention.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones(
(sequence_length, micro_batch_size, checkpointed_parallel_attention.config.hidden_size)
)
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
output, bias = checkpointed_parallel_attention(hidden_states, attention_mask)
assert config.recompute_granularity == 'selective'
assert output.shape[0] == sequence_length
assert output.shape[1] == micro_batch_size
assert output.shape[2] == config.hidden_size
assert bias.shape[0] == config.hidden_size
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import pytest
import torch
from megatron.core.transformer.attention import CrossAttention
"""
@pytest.fixture
def core_attention(transformer_config):
return CrossAttention(transformer_config)
class TestCoreAttention:
def test_constructor(self, core_attention):
assert isinstance(core_attention, CrossAttention)
assert core_attention.layer_number == 1
num_weights = sum([p.numel() for p in core_attention.parameters()])
assert num_weights == 0
def test_cpu_forward(self, core_attention):
# we can't currently do this because the global memory buffer is on GPU
pass
def test_gpu_forward(self, core_attention):
# destroy_global_memory_buffer()
# _set_global_memory_buffer()
# model_parallel_cuda_manual_seed(123)
core_attention.cuda()
config = core_attention.config
sequence_length = 32
micro_batch_size = 2
# query_layer (float): [sequence_length, micro_batch_size, num_attention_heads, hidden_size / num_attention_heads]
query_layer = torch.ones(
(
sequence_length,
micro_batch_size,
config.num_attention_heads,
config.hidden_size // config.num_attention_heads,
)
).cuda()
key_layer = torch.ones_like(query_layer).cuda()
value_layer = torch.ones_like(query_layer).cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
context_layer = core_attention(
query_layer=query_layer, key_layer=key_layer, value_layer=value_layer, attention_mask=attention_mask
)
assert context_layer.shape[0] == sequence_length
assert context_layer.shape[1] == micro_batch_size
assert context_layer.shape[2] == config.hidden_size
assert context_layer.device.type == 'cuda'
assert context_layer.dtype == torch.float32
"""
\ No newline at end of file
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import pytest
import torch
from megatron.core.transformer.mlp import MLP
from tests.unit_tests.test_utilities import Utils
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
class TestParallelMLP:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.mlp = MLP(transformer_config)
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_constructor(self):
assert isinstance(self.mlp, MLP)
num_weights = sum([p.numel() for p in self.mlp.parameters()])
assert num_weights == 1236
"""
def test_cpu_forward(self, mlp):
# [sequence length, micro batch size, hidden size]
hidden_states = torch.ones((32, 2, mlp.config.hidden_size))
output, output_bias = mlp(hidden_states)
assert output.shape[0] == 32
assert output.shape[1] == 2
assert output.shape[2] == mlp.config.hidden_size
assert output_bias.shape[0] == mlp.config.hidden_size
assert output.dtype == torch.float32
"""
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_gpu_forward(self):
mlp = self.mlp
mlp.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((32, 2, mlp.config.hidden_size))
hidden_states = hidden_states.cuda()
output, output_bias = mlp(hidden_states)
assert output.shape[0] == 32
assert output.shape[1] == 2
assert output.shape[2] == mlp.config.hidden_size
assert output_bias.shape[0] == mlp.config.hidden_size
assert output.dtype == torch.float32
assert output.device.type == 'cuda'
assert output_bias.device.type == 'cuda'
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import pytest
import torch
from megatron.core.transformer.module import Float16Module, MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
DEVICE_CAPABILITY = None
if torch.cuda.is_available():
DEVICE_CAPABILITY = torch.cuda.get_device_capability()
class DummyModule(MegatronModule):
# def __init__(self, config: TransformerConfig, share_embeddings_and_output_weights=True):
def __init__(self, config: TransformerConfig):
super().__init__(config)
self.linear = torch.nn.modules.Linear(in_features=2, out_features=1)
def forward(self, x):
return self.linear(x)
class TestMegatronModule:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.megatron_module = DummyModule(config=transformer_config).cuda()
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_megatron_module(self):
megatron_module = self.megatron_module
assert megatron_module
assert megatron_module.config.hidden_size == 12
assert megatron_module.config.ffn_hidden_size == 48
assert megatron_module.linear.weight.dtype == torch.float32
x = torch.ones((2, 2)).cuda()
assert megatron_module(x).dtype == torch.float32
# TODO: test bad configs actually fail
# failed_module = megatron_module
# failed_module.fp16 = True
# failed_module.bf16 = True
class TestFloat16Module:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
self.transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.megatron_module = DummyModule(config=self.transformer_config).cuda()
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_fp16_module(self):
transformer_config = self.transformer_config
megatron_module = self.megatron_module
transformer_config.fp16 = True
fp16_module = Float16Module(config=transformer_config, module=megatron_module)
assert fp16_module
assert fp16_module.config.hidden_size == 12
assert fp16_module.config.ffn_hidden_size == 48
assert fp16_module.module.linear.weight.dtype == torch.float16
x = torch.ones((2, 2)).cuda()
# inputs are converted to fp16 then outputs are converted to fp32
assert fp16_module(x).dtype == torch.float32
@pytest.mark.skipif(
not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, reason='bfloat16 is not supported on this device'
)
def test_bf16_module(self):
transformer_config = self.transformer_config
megatron_module = self.megatron_module
transformer_config.bf16 = True
bf16_module = Float16Module(config=transformer_config, module=megatron_module)
assert bf16_module
assert bf16_module.config.hidden_size == 12
assert bf16_module.config.ffn_hidden_size == 48
assert bf16_module.module.linear.weight.dtype == torch.bfloat16
x = torch.ones((2, 2)).cuda()
# inputs are converted to bf16 then outputs are converted to fp32
assert bf16_module(x).dtype == torch.float32
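# A minimal sketch of the wrapping pattern exercised by test_fp16_module and
# test_bf16_module above: parameters and inputs are cast down to half
# precision, and outputs are cast back to fp32. This is a hypothetical
# stand-in for illustration, not Megatron's actual Float16Module.
import torch

class HalfPrecisionWrapper(torch.nn.Module):
    def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.float16):
        super().__init__()
        self.dtype = dtype
        self.module = module.to(dtype)  # cast parameters to fp16/bf16

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.module(x.to(self.dtype))  # run the wrapped module in low precision
        return out.float()                   # hand fp32 back to the caller

if __name__ == '__main__':
    # bf16 linear layers also run on CPU, so this small check needs no GPU.
    wrapper = HalfPrecisionWrapper(torch.nn.Linear(2, 2), dtype=torch.bfloat16)
    assert wrapper(torch.ones(2, 2)).dtype == torch.float32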
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
import pytest
import torch
from megatron.core import dist_checkpointing
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.transformer.transformer_block import TransformerBlock
from tests.unit_tests.test_utilities import Utils
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
class TestParallelTransformerBlock:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
self.transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.parallel_transformer_block = TransformerBlock(self.transformer_config)
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_constructor(self):
parallel_transformer_block = self.parallel_transformer_block
assert isinstance(parallel_transformer_block, TransformerBlock)
num_weights = sum([p.numel() for p in parallel_transformer_block.parameters()])
assert num_weights == 3792
assert parallel_transformer_block.num_layers_per_pipeline_rank == 2
assert len(parallel_transformer_block.layers) == 2
layer_0: TransformerLayer = parallel_transformer_block._get_layer(0)
assert layer_0.layer_number == 1
layer_1: TransformerLayer = parallel_transformer_block._get_layer(1)
assert layer_1.layer_number == 2
def test_gpu_forward(self):
parallel_transformer_block = self.parallel_transformer_block
config: TransformerConfig = parallel_transformer_block.config
sequence_length = 32
micro_batch_size = 2
parallel_transformer_block.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size))
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
hidden_states = parallel_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask)
assert hidden_states.shape[0] == sequence_length
assert hidden_states.shape[1] == micro_batch_size
assert hidden_states.shape[2] == config.hidden_size
def test_gpu_forward_full_checkpoint(self):
transformer_config = self.transformer_config
config = transformer_config
config.recompute_granularity = 'full'
config.recompute_method = 'block'
config.recompute_num_layers = config.num_layers
full_transformer_block = TransformerBlock(config)
assert full_transformer_block.config.recompute_granularity == 'full'
assert full_transformer_block.config.recompute_method == 'block'
sequence_length = 32
micro_batch_size = 2
full_transformer_block.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size))
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
hidden_states = full_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask)
assert hidden_states.shape[0] == sequence_length
assert hidden_states.shape[1] == micro_batch_size
assert hidden_states.shape[2] == config.hidden_size
def test_gpu_forward_selective_checkpoint(self):
transformer_config = self.transformer_config
config = transformer_config
config.recompute_granularity = 'selective'
selective_transformer_block = TransformerBlock(config)
assert selective_transformer_block.config.recompute_granularity == 'selective'
assert selective_transformer_block.checkpoint_core_attention
sequence_length = 32
micro_batch_size = 2
selective_transformer_block.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size))
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
hidden_states = selective_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask)
assert hidden_states.shape[0] == sequence_length
assert hidden_states.shape[1] == micro_batch_size
assert hidden_states.shape[2] == config.hidden_size
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import pytest
import torch
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from tests.unit_tests.test_utilities import Utils
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
class TestParallelTransformerLayer:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.parallel_transformer_layer = TransformerLayer(transformer_config)
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_constructor(self):
parallel_transformer_layer = self.parallel_transformer_layer
assert isinstance(parallel_transformer_layer, TransformerLayer)
assert parallel_transformer_layer.layer_number == 1
num_weights = sum([p.numel() for p in parallel_transformer_layer.parameters()])
assert num_weights == 1884
def test_gpu_forward(self):
parallel_transformer_layer = self.parallel_transformer_layer
config: TransformerConfig = parallel_transformer_layer.config
sequence_length = 32
micro_batch_size = 2
parallel_transformer_layer.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size))
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
hidden_states = parallel_transformer_layer(hidden_states=hidden_states, attention_mask=attention_mask)
assert hidden_states.shape[0] == sequence_length
assert hidden_states.shape[1] == micro_batch_size
assert hidden_states.shape[2] == config.hidden_size
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
# for now we just format core
black ${SCRIPT_DIR}/../megatron/core
isort ${SCRIPT_DIR}/../megatron/core
@@ -11,9 +11,10 @@ from tqdm import tqdm
from megatron import get_args, get_tokenizer, print_rank_0
from megatron import core
from megatron.arguments import core_transformer_config_from_args
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.model import BertModel
from megatron.schedules import get_forward_backward_func
from megatron.training import setup_model_and_optimizer
from .dataset import BertEmbeddingDataset
@@ -28,8 +29,10 @@ def model_provider(pre_process=True, post_process=True):
print_rank_0(" > build Bert model.")
args = get_args()
config = core_transformer_config_from_args(args)
num_tokentypes = 2 if args.bert_binary_head else 0
model = BertModel(
config=config,
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True,
@@ -104,14 +104,14 @@ def get_missing_blocks(workdir, n_samples, block_size,
try:
f = h5py.File(path, "r")
except:
raise Exception("unable to open/validate '%s'." % path)
# raise Exception("unable to open/validate '%s'." % path)
os.remove(path)
continue
try:
validate(f)
except:
raise Exception("delete block file.")
# raise Exception("delete block file '%s'." % path)
os.remove(path)
finally:
f.close()
@@ -156,53 +156,38 @@ def get_missing_blocks_by_rank(workdir, n_samples, block_size,
return len(missing_blocks), rank_missing_blocks
class IdPathMap:
'''Maps indexes to the containing block path.
class BlockPathMap:
'''Map an index to its containing block path.
This class optimizes the mapping of a large number of indexes to the
path of its containing block. For example, with block_size 1M, this class
stores 1/1M as many (long) path strings, saving memory.
The common use for this class is to have a directory of files containing
blocks of processed data, of uniform block size (e.g., 100k samples per
file). Each file must follow a naming convention of 'startIdx-endIdx.[ext]',
where 'endIdx' minus 'startIdx' must equal the block size, with the possible
exception of the final block. Given an input index, this class maps the
index to the containing block file.
'''
def __init__(self, paths):
self.paths = paths
self.path_index_map = {p:i for i,p in enumerate(paths)}
self.id_index_map = {}
@classmethod
def from_dir(cls, _dir, block_size, ext="hdf5"):
'''Get list of block files, and create map.'''
assert os.path.isdir(_dir), f"directory not found, '{_dir}'."
return cls(sorted(glob.glob(_dir + f"/*.{ext}")), block_size)
def __init__(self, block_paths, block_size):
self.max_idx = 0
self.block_path_map = {}
for block_path in block_paths:
name = os.path.splitext(os.path.basename(block_path))[0]
start_idx, end_idx = [ int(i) for i in name.split("-") ]
self.block_path_map[start_idx] = block_path
self.max_idx = max(self.max_idx, end_idx)
self.block_size = block_size
def __str__(self):
return "%d paths; %d ids" % (len(self.paths), len(self.id_index_map))
def add(self, id, path):
'''Map index to a path.'''
self.id_index_map[id] = self.path_index_map[path]
def __contains__(self, idx):
'''Index added to this object?'''
return idx in self.id_index_map
return "%d paths" % len(self.block_path_map)
def __getitem__(self, idx):
'''Get path from index.'''
return self.paths[self.id_index_map[idx]]
def path_to_range(path):
'''Parse start/end indexes from block path name (e.g., 00010-00011.hdf5 ->
(10, 11)).'''
return tuple([
int(i) for i in os.path.splitext(
os.path.basename(path))[0].split("-")])
def get_index_path_map(_dir):
'''Map contained indexes to block file path (on disk).'''
paths = sorted(glob.glob(_dir + "/*.hdf5"))
# Build index-path map.
idx_path_map = IdPathMap(paths)
for path in paths:
start_idx, end_idx = path_to_range(path)
for idx in range(start_idx, end_idx):
idx_path_map.add(idx, path)
return idx_path_map
'''Get block path from index.'''
block_start_idx = self.block_size * (idx // self.block_size)
block_path = self.block_path_map[block_start_idx]
return block_path
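# A minimal usage sketch of the BlockPathMap class above (paths and block size
# here are illustrative; block files are assumed to follow the
# 'startIdx-endIdx.hdf5' naming convention described in the docstring).
block_paths = ["/data/blocks/0-100000.hdf5", "/data/blocks/100000-200000.hdf5"]
path_map = BlockPathMap(block_paths, block_size=100000)

# An index maps to the path of the block whose range contains it.
assert path_map[12345] == "/data/blocks/0-100000.hdf5"
assert path_map[150000] == "/data/blocks/100000-200000.hdf5"

# Equivalently, the map can be built from a directory of '*.hdf5' block files:
# path_map = BlockPathMap.from_dir("/data/blocks", block_size=100000)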
@@ -55,7 +55,7 @@ def _load_checkpoint(queue, args):
]
margs = parse_args()
margs = load_args_from_checkpoint(margs)
margs, checkpoint_args = load_args_from_checkpoint(margs)
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes
@@ -63,12 +63,15 @@ def _load_checkpoint(queue, args):
margs = validate_args(margs)
def check_for_arg(arg_name):
def check_for_arg(arg_name, default=None):
if getattr(margs, arg_name, None) is None:
print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
print(f"Arguments: {margs}")
queue.put("exit")
exit(1)
if default is not None:
setattr(margs, arg_name, default)
else:
print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
print(f"Arguments: {margs}")
queue.put("exit")
exit(1)
check_for_arg('tensor_model_parallel_size')
check_for_arg('pipeline_model_parallel_size')
@@ -77,10 +80,13 @@ def _load_checkpoint(queue, args):
check_for_arg('seq_length')
check_for_arg('num_attention_heads')
check_for_arg('max_position_embeddings')
check_for_arg('position_embedding_type')
check_for_arg('tokenizer_type')
check_for_arg('iteration')
check_for_arg('bert_binary_head')
check_for_arg('disable_bias_linear', False)
check_for_arg('params_dtype')
check_for_arg('swiglu', False)
# Determine how to make our models
if args.model_type == 'GPT':
@@ -97,18 +103,38 @@ def _load_checkpoint(queue, args):
consumed_train_samples = None
consumed_valid_samples = None
def get_models(count, dtype, pre_process, post_process):
def get_models(count, dtype):
nonlocal consumed_train_samples
nonlocal consumed_valid_samples
models = []
model_array_len = margs.virtual_pipeline_model_parallel_size
if model_array_len is None:
model_array_len = 1
models = [[] for _ in range(model_array_len)]
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
for rank in range(count):
mpu.set_tensor_model_parallel_rank(rank)
model_ = [model_provider(pre_process, post_process).to(dtype)]
if margs.virtual_pipeline_model_parallel_size is not None:
model_ = []
for i in range(margs.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider(
pre_process=pre_process,
post_process=post_process
).to(dtype)
model_.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model_rank = 0
model_ = [model_provider(pre_process, post_process).to(dtype)]
margs.consumed_train_samples = 0
margs.consumed_valid_samples = 0
load_checkpoint(model_, None, None)
assert(len(model_) == 1)
model_ = model_[0]
if consumed_train_samples is not None:
assert(margs.consumed_train_samples == consumed_train_samples)
else:
@@ -117,17 +143,14 @@ def _load_checkpoint(queue, args):
assert(margs.consumed_valid_samples == consumed_valid_samples)
else:
consumed_valid_samples = margs.consumed_valid_samples
models.append(model_)
for vp_rank in range(model_array_len):
models[vp_rank].append(model_[vp_rank])
return models
if margs.num_layers_per_virtual_pipeline_stage is not None:
print("Model with an interleaved pipeline schedule are not yet supported.")
queue.put("exit")
exit(1)
set_global_variables(margs)
set_global_variables(margs, build_tokenizer=False)
mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size)
fused_kernels.load(margs)
# Get true (non-padded) vocab size
@@ -146,6 +169,9 @@ def _load_checkpoint(queue, args):
# short aliases
tp_size = margs.tensor_model_parallel_size
pp_size = margs.pipeline_model_parallel_size
vp_size = margs.virtual_pipeline_model_parallel_size
if vp_size is None:
vp_size = 1
# metadata
md = types.SimpleNamespace()
@@ -159,15 +185,20 @@ def _load_checkpoint(queue, args):
md.iteration = margs.iteration
md.params_dtype = margs.params_dtype
md.bert_binary_head = margs.bert_binary_head
md.output_layer = margs.untie_embeddings_and_output_weights
md.position_embedding_type = margs.position_embedding_type
md.linear_bias = margs.add_bias_linear
md.swiglu = margs.swiglu
md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
md.true_vocab_size = true_vocab_size
md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
md.checkpoint_args = checkpoint_args
# Get first pipe stage
mpu.set_pipeline_model_parallel_rank(0)
post_process = pp_size == 1
models = get_models(tp_size, md.params_dtype, True, post_process)
all_models = [get_models(tp_size, md.params_dtype)]
models = all_models[0][0]
md.consumed_train_samples = consumed_train_samples
md.consumed_valid_samples = consumed_valid_samples
@@ -180,59 +211,83 @@ def _load_checkpoint(queue, args):
# Send embeddings
message = {
"position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
"word embeddings": torch.cat(
[models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
dim = 0)
}
if md.position_embedding_type == 'learned_absolute':
message["position embeddings"] = models[0].language_model.embedding.position_embeddings.weight.data
else:
assert not hasattr(models[0].language_model.embedding, 'position_embeddings')
queue_put("embeddings", message)
total_layer_num = 0
for pp_rank in range(pp_size):
if pp_rank > 0:
mpu.set_pipeline_model_parallel_rank(pp_rank)
post_process = pp_rank == pp_size - 1
models = get_models(tp_size, md.params_dtype, False, post_process)
for layer_num in range(len(models[0].language_model.encoder.layers)):
message = {}
# Get non-parallel tensors from tp_rank 0
layer = models[0].language_model.encoder.layers[layer_num]
message["input layernorm weight"] = layer.input_layernorm.weight.data
message["input layernorm bias"] = layer.input_layernorm.bias.data
message["dense bias"] = layer.self_attention.dense.bias.data
message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
# Grab all parallel tensors for this layer
qkv_weight = []
qkv_bias = []
dense_weight = []
mlp_l0_weight = []
mlp_l0_bias = []
mlp_l1_weight = []
for tp_rank, model in enumerate(models):
layer = model.language_model.encoder.layers[layer_num]
qkv_weight.append(layer.self_attention.query_key_value.weight.data)
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
dense_weight.append(layer.self_attention.dense.weight.data)
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
# concat them
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
message["dense weight"] = torch.cat(dense_weight, dim=1)
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
queue_put(f"transformer layer {total_layer_num}", message)
total_layer_num = total_layer_num + 1
for vp_rank in range(vp_size):
mpu.set_virtual_pipeline_model_parallel_rank(vp_rank)
for pp_rank in range(pp_size):
if pp_rank > 0:
mpu.set_pipeline_model_parallel_rank(pp_rank)
if vp_rank == 0:
all_models.append(get_models(tp_size, md.params_dtype))
models = all_models[pp_rank][vp_rank]
for layer_num in range(len(models[0].language_model.encoder.layers)):
message = {}
# Get non-parallel tensors from tp_rank 0
layer = models[0].language_model.encoder.layers[layer_num]
message["input layernorm weight"] = layer.input_layernorm.weight.data
message["input layernorm bias"] = layer.input_layernorm.bias.data
message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
if md.linear_bias:
message["dense bias"] = layer.self_attention.dense.bias.data
message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
# Grab all parallel tensors for this layer
qkv_weight = []
qkv_bias = []
dense_weight = []
mlp_l0_weight = []
mlp_l0_bias = []
mlp_l1_weight = []
for tp_rank, model in enumerate(models):
layer = model.language_model.encoder.layers[layer_num]
qkv_weight.append(layer.self_attention.query_key_value.weight.data)
dense_weight.append(layer.self_attention.dense.weight.data)
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
if md.linear_bias:
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
# Handle gated linear units
if md.swiglu:
# concat all the first halves ('W's) and all the second halves ('V's)
for tp_rank in range(tp_size):
mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0)
message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0)
message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0)
else:
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
# simple concat of the rest
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["dense weight"] = torch.cat(dense_weight, dim=1)
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
if md.linear_bias:
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
if md.swiglu:
for tp_rank in range(tp_size):
mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0)
message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias],dim=0)
message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias],dim=0)
else:
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
queue_put(f"transformer layer {total_layer_num}", message)
total_layer_num = total_layer_num + 1
# Send final layernorm from tp_rank 0
message = {
@@ -241,6 +296,15 @@ def _load_checkpoint(queue, args):
}
queue_put("final layernorm", message)
if md.output_layer:
message = {
"weight": torch.cat(
[models[tp_rank].language_model.output_layer.weight.data for tp_rank in range(tp_size)],
dim = 0)
}
queue_put("output layer", message)
# Send BERT lm head and binary head if it exists
if md.model_type == 'BERT':
message = {
@@ -96,6 +96,7 @@ def save_checkpoint(queue, args):
'--seq-length', str(md.seq_length),
'--num-attention-heads', str(md.num_attention_heads),
'--max-position-embeddings', str(md.max_position_embeddings),
'--position-embedding-type', str(md.position_embedding_type),
'--tokenizer-type', str(md.tokenizer_type),
'--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
'--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
@@ -121,12 +122,47 @@ def save_checkpoint(queue, args):
elif md.params_dtype == torch.bfloat16:
sys.argv.append('--bf16')
if md.output_layer:
sys.argv.append('--untie-embeddings-and-output-weights')
if not md.linear_bias:
sys.argv.append('--disable-bias-linear')
if md.model_type == 'BERT' and not md.bert_binary_head:
sys.argv.append('--bert-no-binary-head')
margs = parse_args()
if hasattr (md, 'checkpoint_args'):
# These are arguments that we are either changing, or cause problems for validation if they are set
# Note that some of these deal with T5 so will need to be changed if we support T5.
args_to_keep = ['tensor_model_parallel_size', 'pipeline_model_parallel_size', 'world_size', 'params_dtype',
'num_layers_per_virtual_pipeline_stage', 'virtual_pipeline_model_parallel_size',
'masked_softmax_fusion', 'bias_gelu_fusion', 'bias_dropout_fusion',
'sequence_parallel', 'async_tensor_model_parallel_allreduce',
'no_load_optim', 'no_load_rng', 'no_save_optim', 'no_save_rng',
'vocab_file', 'tokenizer_model',
'save_interval', 'save',
'perform_initialization', 'use_cpu_initialization',
'encoder_num_layers', 'encoder_seq_length',
'distribute_saved_activations',
'train_iters', 'lr_decay_iters', 'lr_warmup_iters', 'lr_warmup_fraction',
'start_weight_decay', 'end_weight_decay']
for arg, value in vars(md.checkpoint_args).items():
if arg in args_to_keep:
continue
if not hasattr(margs, arg):
print(f"Checkpoint had argument {arg} but new arguments does not have this.")
continue
if getattr(margs, arg) != value:
print(f"Overwriting default {arg} value {getattr(margs, arg)} with value from checkpoint {value}.")
setattr(margs, arg, value)
validate_args(margs)
set_global_variables(margs)
set_global_variables(margs, build_tokenizer=False)
# margs = megatron args
margs = get_args()
@@ -164,7 +200,9 @@ def save_checkpoint(queue, args):
#-----------
embeddings_msg = queue_get("embeddings")
pos_embed = embeddings_msg.pop("position embeddings")
pos_embed = None
if md.position_embedding_type == 'learned_absolute':
pos_embed = embeddings_msg.pop("position embeddings")
orig_word_embed = embeddings_msg.pop("word embeddings")
check_message(embeddings_msg)
@@ -203,9 +241,11 @@ def save_checkpoint(queue, args):
post_process = args.target_pipeline_parallel_size == 1
models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
for tp_rank, model in enumerate(models):
print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
if pos_embed is not None:
model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
else:
assert not hasattr(model.language_model.embedding, "position_embeddings")
# Transformer layers
#-------------------
@@ -223,34 +263,51 @@ def save_checkpoint(queue, args):
# duplicated tensors
input_layernorm_weight = msg.pop("input layernorm weight")
input_layernorm_bias = msg.pop("input layernorm bias")
dense_bias = msg.pop("dense bias")
post_layernorm_weight = msg.pop("post layernorm weight")
post_layernorm_bias = msg.pop("post layernorm bias")
mlp_l1_bias = msg.pop("mlp l1 bias")
if md.linear_bias:
dense_bias = msg.pop("dense bias")
mlp_l1_bias = msg.pop("mlp l1 bias")
# Split up the parallel tensors
qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)
# Special handling for swiglu
if md.swiglu:
mlp_l0_weight_W = torch.chunk(msg.pop("mlp l0 weight W"), args.target_tensor_parallel_size, dim=0)
mlp_l0_weight_V = torch.chunk(msg.pop("mlp l0 weight V"), args.target_tensor_parallel_size, dim=0)
mlp_l0_weight = [torch.cat(weights, dim=0) for weights in zip(mlp_l0_weight_W, mlp_l0_weight_V)]
else:
mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
if md.linear_bias:
qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
if md.swiglu:
mlp_l0_bias_W = torch.chunk(msg.pop("mlp l0 bias W"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias_V = torch.chunk(msg.pop("mlp l0 bias V"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias = [torch.cat(bias, dim=0) for bias in zip(mlp_l0_bias_W, mlp_l0_bias_V)]
else:
mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
# Save them to the model
for tp_rank in range(args.target_tensor_parallel_size):
l = models[tp_rank].language_model.encoder.layers[layer]
l.input_layernorm.weight.data.copy_(input_layernorm_weight)
l.input_layernorm.bias.data.copy_(input_layernorm_bias)
l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
l.self_attention.dense.bias.data.copy_(dense_bias)
l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
if md.linear_bias:
l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
l.self_attention.dense.bias.data.copy_(dense_bias)
l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
total_layer_num = total_layer_num + 1
check_message(msg)
@@ -262,13 +319,24 @@ def save_checkpoint(queue, args):
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
if pp_rank != 0:
if pp_rank != 0 and not md.output_layer:
# Copy word embeddings to final pipeline rank
models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
del final_layernorm_weight
del final_layernorm_bias
check_message(msg)
if md.output_layer:
msg = queue_get("output layer")
if not hasattr(models[0].language_model, 'output_layer'):
print("ERROR: got an output layer, but model does not have one")
exit(1)
output_layer_weight = torch.chunk(msg.pop("weight"), args.target_tensor_parallel_size, dim=0)
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.output_layer.weight.data.copy_(output_layer_weight[tp_rank])
del output_layer_weight
check_message(msg)
msg = queue_get()
if msg != "done" and msg["name"] == "pooler":
if not hasattr(models[0].language_model, 'pooler'):
@@ -14,7 +14,7 @@ The following steps show how to prepare training dataset to train the model.
1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ)
2. Remove blacklisted URLs.
```
python blacklist_urls.py <path to the dowloaded deduplicated URLs> <filename for clean urls. e.g. clean_urls.txt>
python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for clean urls. e.g. clean_urls.txt>
```
3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py).
@@ -37,7 +37,7 @@ python group_duplicate_urls.py <possible duplicate urls file> <output file conta
```
4. Remove similar documents that were detected in the last step.
```
python remove_group_duplicates.py <file containing simialr documents> <cleaned data file> <outputfile containing deduplicate data>
python remove_group_duplicates.py <file containing similar documents> <cleaned data file> <outputfile containing deduplicate data>
```
5. Shuffle the dataset.
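The shuffle step is left tool-agnostic; below is a minimal Python sketch (file names are illustrative, and a line-level shuffle such as `shuf` works just as well):
```python
# Shuffle the cleaned, deduplicated line-per-document file in memory and
# write it back out in random order. File names below are examples.
import random

with open("cleaned_deduped_data.json", "r", encoding="utf-8") as f:
    lines = f.readlines()

random.shuffle(lines)

with open("train_data.json", "w", encoding="utf-8") as f:
    f.writelines(lines)
```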
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# WARNING! This file contains a blacklist of known malicious sites and thus contains some NSFW language.
import glob
@@ -47,6 +49,7 @@ domain_blacklist = set([
'google',
'gunprime',
'gyazo',
'horsefucker',
'hotdealstar',
'imagefap',
'imageshack',
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Processing data for pretraining."""
"""Processing large data for pretraining."""
import argparse
import math
import json
import multiprocessing
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import time
import gzip
import glob
import torch
import numpy as np
import multiprocessing
try:
import nltk
nltk_available = True
@@ -39,6 +41,7 @@ class IdentitySplitter(object):
def tokenize(self, *text):
return text
class Encoder(object):
def __init__(self, args):
self.args = args
@@ -51,33 +54,129 @@ class Encoder(object):
print("NLTK is not available to split sentences.")
exit()
library = "tokenizers/punkt/{}.pickle".format(self.args.lang)
print("loading: " + library)
splitter = nltk.load(library)
if self.args.keep_newlines:
# this prevents punkt from eating newlines after sentences
Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
train_text=splitter._params,
lang_vars=CustomLanguageVars())
train_text = splitter._params,
lang_vars = CustomLanguageVars())
else:
Encoder.splitter = splitter
else:
Encoder.splitter = IdentitySplitter()
def split(self, json_line):
data = json.loads(json_line)
output = {}
for key in self.args.json_keys:
text = data[key]
max_len = 1000000
tokens_list = [Encoder.splitter.tokenize(text[i:i+max_len]) for i in range(0, len(text), max_len)]
output[key] = [tokens for partial in tokens_list for tokens in partial]
return json.dumps(output), len(json_line)
def encode(self, json_line):
data = json.loads(json_line)
ids = {}
lens = {}
for key in self.args.json_keys:
text = data[key]
if isinstance(text, list):
sentences = text
else:
sentences = [text]
doc_ids = []
for sentence in Encoder.splitter.tokenize(text):
sentence_lens = []
for sentence in sentences:
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.append(sentence_ids)
doc_ids.extend(sentence_ids)
sentence_lens.append(len(sentence_ids))
if len(doc_ids) > 0 and self.args.append_eod:
doc_ids[-1].append(Encoder.tokenizer.eod)
doc_ids.append(Encoder.tokenizer.eod)
sentence_lens[-1] += 1
ids[key] = doc_ids
return ids, len(json_line)
lens[key] = sentence_lens
return ids, lens, len(json_line)
class Partition(object):
def __init__(self, args, workers):
self.args = args
self.workers = workers
def print_processing_stats(self, count, proc_start, total_bytes_processed):
if count % self.args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {count} documents",
f"({count/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
def split_sentences(self, file_name):
input_file_name, output_file_name = file_name
print("Opening", input_file_name)
fin = open(input_file_name, 'r', encoding='utf-8')
fout = open(output_file_name, 'w')
encoder = Encoder(self.args)
pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
split_docs = pool.imap(encoder.split, fin, 32)
proc_start = time.time()
total_bytes_processed = 0
for i, (doc, bytes_processed) in enumerate(split_docs, start=1):
total_bytes_processed += bytes_processed
fout.write(doc + "\n")
self.print_processing_stats(i, proc_start, total_bytes_processed)
fin.close()
fout.close()
def process_json_file(self, file_name):
input_file_name, output_prefix = file_name
print("Opening", input_file_name)
fin = open(input_file_name, 'r', encoding='utf-8')
startup_start = time.time()
encoder = Encoder(self.args)
tokenizer = build_tokenizer(self.args)
pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 32)
level = "document"
if self.args.split_sentences:
level = "sentence"
output_bin_files = {}
output_idx_files = {}
builders = {}
for key in self.args.json_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix,
key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix,
key, level)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=self.args.dataset_impl,
vocab_size=tokenizer.vocab_size)
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (doc, sentence_lens, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for key in doc.keys():
builders[key].add_doc(doc[key], sentence_lens[key])
self.print_processing_stats(i, proc_start, total_bytes_processed)
fin.close()
builders[key].finalize(output_idx_files[key])
def get_args():
parser = argparse.ArgumentParser()
@@ -94,20 +193,21 @@ def get_args():
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'],
'GPT2BPETokenizer', 'SentencePieceTokenizer',
'GPTSentencePieceTokenizer', 'NullTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='YTTM tokenizer model.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--vocab-size', default=786,
help='size of vocab for use with NullTokenizer')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')
group.add_argument('--lang', type=str, default='english',
help='Language to use for NLTK-powered sentence splitting.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='sentencepiece tokenizer model.')
group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
@@ -116,84 +216,187 @@ def get_args():
group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, required=True,
help='Number of worker processes to launch')
group.add_argument('--chunk-size', type=int, required=True,
help='Chunk size assigned to each worker process')
group.add_argument('--log-interval', type=int, default=100,
help=('Number of worker processes to launch.'
'A good default for fast pre-processing '
'is: (workers * partitions) = available CPU cores.'))
group.add_argument('--partitions', type=int, default=1,
help='Number of file partitions')
group.add_argument('--log-interval', type=int, default=1000,
help='Interval between progress updates')
group.add_argument('--keep-sequential-samples', action='store_true',
help='Ensure ordering of samples in .jsonl files is '
'preserved when using partitions>1.')
args = parser.parse_args()
args.keep_empty = False
if args.tokenizer_type.lower().startswith('bert'):
if not args.split_sentences:
print("Bert tokenizer detected, are you sure you don't want to split sentences?")
if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences:
print("Are you sure you don't want to split sentences?")
# some default/dummy values for the tokenizer
args.rank = 0
args.rank = 1
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
args.vocab_extra_ids = 0
return args
def get_file_name(args, file_id):
file_name, extension = os.path.splitext(args.input)
input_file_name = file_name + "_" + str(file_id) + extension
sentence_split_file = file_name + "_ss_" + str(file_id) + extension
output_prefix = args.output_prefix + "_" + str(file_id)
file_names = {
'partition': input_file_name,
'sentence_split': sentence_split_file,
'output_prefix': output_prefix}
return file_names
def check_files_exist(in_ss_out_names, key, num_partitions):
for i in range(num_partitions):
if not os.path.exists(in_ss_out_names[i][key]):
return False
return True
def main():
args = get_args()
startup_start = time.time()
print("Opening", args.input)
fin = open(args.input, 'r', encoding='utf-8')
if args.split_sentences:
if nltk_available:
nltk.download("punkt", quiet=True)
else:
raise Exception(
"nltk library required for sentence splitting is not available.")
if nltk_available and args.split_sentences:
nltk.download("punkt", quiet=True)
in_ss_out_names = []
if args.partitions == 1:
file_name, extension = os.path.splitext(args.input)
sentence_split_file = file_name + "_ss" + extension
file_names = {
'partition': args.input,
'sentence_split': sentence_split_file,
'output_prefix': args.output_prefix}
in_ss_out_names.append(file_names)
else:
in_file_names = glob.glob(args.input)
encoder = Encoder(args)
tokenizer = build_tokenizer(args)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size)
#encoded_docs = map(encoder.encode, fin)
# Count total number of lines across .jsonl files
if args.keep_sequential_samples:
total_sample_count = 0
for filename in in_file_names:
with open(filename, "r") as fin:
for fc, _ in enumerate(fin):
pass
total_sample_count += (fc + 1)
partition_size = math.ceil(total_sample_count / args.partitions)
# create .jsonl partition files
for idx in range(args.partitions):
in_ss_out_name = get_file_name(args, idx)
in_ss_out_names.append(in_ss_out_name)
# check to see if partitions were already created
partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions)
# check to see if partitions with split sentences already created
split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)
if not partitions_present and not split_sentences_present:
# populate .jsonl partition files from parent files
partitioned_input_files = []
for idx in range(args.partitions):
partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w')
partitioned_input_files.append(partitioned_input_file)
index = 0
if args.keep_sequential_samples: line_count = 0
for in_file_name in in_file_names:
# support for gzip files
if in_file_name.endswith(".gz"):
fin = gzip.open(in_file_name, 'rt')
else:
fin = open(in_file_name, 'r', encoding='utf-8')
for line in fin:
partitioned_input_files[index].write(line)
if args.keep_sequential_samples:
line_count += 1
if line_count % partition_size == 0:
index += 1
else:
index = (index + 1)%args.partitions
fin.close()
for idx in range(args.partitions):
partitioned_input_files[idx].close()
assert args.workers % args.partitions == 0
partition = Partition(args, args.workers//args.partitions)
# check to see if partitions with split sentences already created
split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)
# split sentences in partition files
if args.split_sentences and not split_sentences_present:
processes = []
for name in in_ss_out_names:
p = multiprocessing.Process(target=partition.split_sentences,
args=((name['partition'], name['sentence_split']),))
p.start()
processes.append(p)
for p in processes:
p.join()
if args.partitions == 1:
return
# encode partition files in parallel
processes = []
input_key = 'sentence_split' if args.split_sentences else 'partition'
for name in in_ss_out_names:
p = multiprocessing.Process(target=partition.process_json_file,
args=((name[input_key], name['output_prefix']),))
p.start()
processes.append(p)
for p in processes:
p.join()
if args.partitions == 1:
return
# merge bin/idx partitions
level = "document"
if args.split_sentences:
level = "sentence"
print(f"Vocab size: {tokenizer.vocab_size}")
print(f"Output prefix: {args.output_prefix}")
output_bin_files = {}
output_idx_files = {}
builders = {}
tokenizer = build_tokenizer(args)
for key in args.json_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
key, level)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size)
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for key, sentences in doc.items():
if len(sentences) == 0:
continue
for sentence in sentences:
builders[key].add_item(torch.IntTensor(sentence))
builders[key].end_document()
if i % args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {i} documents",
f"({i/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
print("Done! Now finalizing.")
for key in args.json_keys:
impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size)
for name in in_ss_out_names:
parition_output_prefix = name['output_prefix']
full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix,
key, level)
builders[key].merge_file_(full_partition_output_prefix)
builders[key].finalize(output_idx_files[key])
if __name__ == '__main__':
main()
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Processing large data for pretraining."""
import argparse
import math
import json
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import time
import gzip
import glob
import torch
import numpy as np
import multiprocessing
try:
import nltk
nltk_available = True
except ImportError:
nltk_available = False
from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
_period_context_fmt = r"""
\S* # some word material
%(SentEndChars)s # a potential sentence ending
\s* # <-- THIS is what I changed
(?=(?P<after_tok>
%(NonWord)s # either other punctuation
|
(?P<next_tok>\S+) # <-- Normally you would have \s+ here
))"""
class IdentitySplitter(object):
def tokenize(self, *text):
return text
class Encoder(object):
def __init__(self, args):
self.args = args
def initializer(self):
# Use Encoder class as a container for global data
Encoder.tokenizer = build_tokenizer(self.args)
if self.args.split_sentences:
if not nltk_available:
print("NLTK is not available to split sentences.")
exit()
library = "tokenizers/punkt/{}.pickle".format(self.args.lang)
splitter = nltk.load(library)
if self.args.keep_newlines:
# this prevents punkt from eating newlines after sentences
Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
train_text = splitter._params,
lang_vars = CustomLanguageVars())
else:
Encoder.splitter = splitter
else:
Encoder.splitter = IdentitySplitter()
def split(self, json_line):
data = json.loads(json_line)
output = {}
for key in self.args.json_keys:
text = data[key]
max_len = 1000000
tokens_list = [Encoder.splitter.tokenize(text[i:i+max_len]) for i in range(0, len(text), max_len)]
output[key] = [tokens for partial in tokens_list for tokens in partial]
return json.dumps(output), len(json_line)
def encode(self, json_line):
data = json.loads(json_line)
ids = {}
lens = {}
for key in self.args.json_keys:
text = data[key]
if isinstance(text, list):
sentences = text
else:
sentences = [text]
doc_ids = []
sentence_lens = []
for sentence in sentences:
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.extend(sentence_ids)
sentence_lens.append(len(sentence_ids))
if len(doc_ids) > 0 and self.args.append_eod:
doc_ids.append(Encoder.tokenizer.eod)
ids[key] = doc_ids
lens[key] = sentence_lens
return ids, lens, len(json_line)
class Partition(object):
def __init__(self, args, workers):
self.args = args
self.workers = workers
def print_processing_stats(self, count, proc_start, total_bytes_processed):
if count % self.args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {count} documents",
f"({count/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
def split_sentences(self, file_name):
input_file_name, output_file_name = file_name
print("Opening", input_file_name)
fin = open(input_file_name, 'r', encoding='utf-8')
fout = open(output_file_name, 'w')
encoder = Encoder(self.args)
pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
split_docs = pool.imap(encoder.split, fin, 32)
proc_start = time.time()
total_bytes_processed = 0
for i, (doc, bytes_processed) in enumerate(split_docs, start=1):
total_bytes_processed += bytes_processed
fout.write(doc + "\n")
self.print_processing_stats(i, proc_start, total_bytes_processed)
fin.close()
fout.close()
def process_json_file(self, file_name):
input_file_name, output_prefix = file_name
print("Opening", input_file_name)
fin = open(input_file_name, 'r', encoding='utf-8')
startup_start = time.time()
encoder = Encoder(self.args)
tokenizer = build_tokenizer(self.args)
pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 32)
level = "document"
if self.args.split_sentences:
level = "sentence"
output_bin_files = {}
output_idx_files = {}
builders = {}
for key in self.args.json_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix,
key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix,
key, level)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=self.args.dataset_impl,
vocab_size=tokenizer.vocab_size)
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (doc, sentence_lens, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for key in doc.keys():
builders[key].add_doc(doc[key], sentence_lens[key])
self.print_processing_stats(i, proc_start, total_bytes_processed)
fin.close()
builders[key].finalize(output_idx_files[key])
def get_args():
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--input', type=str, required=True,
help='Path to input JSON')
group.add_argument('--json-keys', nargs='+', default=['text'],
                       help='space-separated list of keys to extract from JSON')
group.add_argument('--split-sentences', action='store_true',
help='Split documents into sentences.')
group.add_argument('--keep-newlines', action='store_true',
help='Keep newlines between sentences when splitting.')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--tokenizer-model', type=str, default=None,
                       help='SentencePiece tokenizer model.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')
group.add_argument('--lang', type=str, default='english',
help='Language to use for NLTK-powered sentence splitting.')
group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
group.add_argument('--dataset-impl', type=str, default='mmap',
choices=['lazy', 'cached', 'mmap'])
group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, default=1,
help='Number of worker processes to launch')
group.add_argument('--partitions', type=int, default=1,
help='Number of file partitions')
group.add_argument('--log-interval', type=int, default=1000,
help='Interval between progress updates')
args = parser.parse_args()
args.keep_empty = False
if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences:
print("Are you sure you don't want to split sentences?")
# some default/dummy values for the tokenizer
args.rank = 1
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
args.vocab_extra_ids = 0
return args
def get_file_name(args, file_id):
file_name, extension = os.path.splitext(args.input)
input_file_name = file_name + "_" + str(file_id) + extension
sentence_split_file = file_name + "_ss_" + str(file_id) + extension
output_prefix = args.output_prefix + "_" + str(file_id)
file_names = {
'partition': input_file_name,
'sentence_split': sentence_split_file,
'output_prefix': output_prefix}
return file_names
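# Illustrative sketch (not part of the original file): how get_file_name derives
# per-partition file names. 'my_corpus.json' and 'my_corpus' are placeholder names.
_demo_args = argparse.Namespace(input='my_corpus.json', output_prefix='my_corpus')
assert get_file_name(_demo_args, 3) == {
    'partition': 'my_corpus_3.json',
    'sentence_split': 'my_corpus_ss_3.json',
    'output_prefix': 'my_corpus_3'}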
def check_files_exist(in_ss_out_names, key, num_partitions):
for i in range(num_partitions):
if not os.path.exists(in_ss_out_names[i][key]):
return False
return True
def main():
args = get_args()
if args.split_sentences:
if nltk_available:
nltk.download("punkt", quiet=True)
else:
raise Exception(
"nltk library required for sentence splitting is not available.")
in_ss_out_names = []
if args.partitions == 1:
file_name, extension = os.path.splitext(args.input)
sentence_split_file = file_name + "_ss" + extension
file_names = {
'partition': args.input,
'sentence_split': sentence_split_file,
'output_prefix': args.output_prefix}
in_ss_out_names.append(file_names)
else:
in_file_names = glob.glob(args.input)
        # create .jsonl partition files
for idx in range(args.partitions):
in_ss_out_name = get_file_name(args, idx)
in_ss_out_names.append(in_ss_out_name)
        # check to see if partitions were already created
partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions)
        # check to see if partitions with split sentences were already created
split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)
if not partitions_present and not split_sentences_present:
# populate .jsonl partition files from parent files
partitioned_input_files = []
for idx in range(args.partitions):
partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w')
partitioned_input_files.append(partitioned_input_file)
index = 0
for in_file_name in in_file_names:
# support for gzip files
if in_file_name.endswith(".gz"):
fin = gzip.open(in_file_name, 'rt')
else:
fin = open(in_file_name, 'r', encoding='utf-8')
for line in fin:
partitioned_input_files[index].write(line)
index = (index + 1)%args.partitions
fin.close()
for idx in range(args.partitions):
partitioned_input_files[idx].close()
assert args.workers % args.partitions == 0
partition = Partition(args, args.workers//args.partitions)
    # check to see if partitions with split sentences were already created
split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)
# split sentences in partition files
if args.split_sentences and not split_sentences_present:
processes = []
for name in in_ss_out_names:
p = multiprocessing.Process(target=partition.split_sentences,
args=((name['partition'], name['sentence_split']),))
p.start()
processes.append(p)
for p in processes:
p.join()
if args.partitions == 1:
return
# encode partition files in parallel
processes = []
input_key = 'sentence_split' if args.split_sentences else 'partition'
for name in in_ss_out_names:
p = multiprocessing.Process(target=partition.process_json_file,
args=((name[input_key], name['output_prefix']),))
p.start()
processes.append(p)
for p in processes:
p.join()
# merge bin/idx partitions
level = "document"
if args.split_sentences:
level = "sentence"
output_bin_files = {}
output_idx_files = {}
builders = {}
tokenizer = build_tokenizer(args)
for key in args.json_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
key, level)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size)
for name in in_ss_out_names:
            partition_output_prefix = name['output_prefix']
            full_partition_output_prefix = "{}_{}_{}".format(partition_output_prefix,
                                                             key, level)
builders[key].merge_file_(full_partition_output_prefix)
builders[key].finalize(output_idx_files[key])
if __name__ == '__main__':
main()
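# Illustrative usage sketch (not part of the original file): running Encoder.encode on a
# single JSON line outside the worker pool. The GPT-2 BPE vocab/merge paths below are
# placeholder assumptions and must point at real files for this to execute.
import types

demo_args = types.SimpleNamespace(
    json_keys=['text'], split_sentences=False, keep_newlines=False,
    append_eod=True, lang='english',
    tokenizer_type='GPT2BPETokenizer',
    vocab_file='gpt2-vocab.json',        # placeholder path
    merge_file='gpt2-merges.txt',        # placeholder path
    tokenizer_model=None, keep_empty=False,
    rank=1, make_vocab_size_divisible_by=128,
    tensor_model_parallel_size=1, vocab_extra_ids=0)

demo_encoder = Encoder(demo_args)
demo_encoder.initializer()               # builds Encoder.tokenizer and Encoder.splitter
demo_line = json.dumps({'text': 'Megatron preprocessing turns raw text into token ids.'})
ids, lens, nbytes = demo_encoder.encode(demo_line)
# ids['text']  -> token ids for the whole document, with <eod> appended
# lens['text'] -> token count per "sentence" (a single entry here, since the input is a plain string)
# nbytes       -> len(demo_line), used only for throughput reporting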
# coding=utf-8
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Processing text modality data for MultiModal pretraining."""
import argparse
import json
import multiprocessing
import os
import sys
import numpy as np
from torchvision.transforms import ToTensor
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import time
import torch
try:
import nltk
nltk_available = True
except ImportError:
nltk_available = False
from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset
from megatron.data.indexed_dataset import MMapIndexedDatasetBuilder
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
_period_context_fmt = r"""
\S* # some word material
%(SentEndChars)s # a potential sentence ending
\s* # <-- THIS is what I changed
(?=(?P<after_tok>
%(NonWord)s # either other punctuation
|
(?P<next_tok>\S+) # <-- Normally you would have \s+ here
))"""
class IdentitySplitter(object):
def tokenize(self, *text):
return text
class Encoder(object):
def __init__(self, args):
self.args = args
def initializer(self):
# Use Encoder class as a container for global data
Encoder.tokenizer = build_tokenizer(self.args)
def encode(self, input_pair):
json_line, img_file = input_pair
data = json.loads(json_line)
key = "text"
text = data[key]
sentence_ids = Encoder.tokenizer.tokenize(text)
pad_len = self.args.pad_length
if len(sentence_ids) > 0 and self.args.append_eod:
sentence_ids = sentence_ids[:pad_len]
current_length = len(sentence_ids)
sentence_ids.extend([Encoder.tokenizer.eod for _ in range(max(0,pad_len-current_length))])
        with open(img_file[:-1], "rb") as tf:  # [:-1] strips the trailing newline from the image-list entry
xs = bytearray(tf.read())
img_pad = (4 - len(xs) % 4) % 4
xs.extend([0 for _ in range(img_pad)])
img_raw = np.frombuffer(xs, dtype=np.int32)
img_raw = np.insert(img_raw, 0, img_pad)
return sentence_ids, img_raw, len(json_line)
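# Illustrative sketch (not part of the original file): the image byte-packing used in
# encode() above. Bytes are padded to a multiple of 4 so they can be viewed as int32,
# and the pad length is stored as the first element so decoding can strip it again.
demo_bytes = bytearray(b'\x89PNG\r\n\x1a\n' + b'\x00' * 10)   # 18 bytes of fake image data
demo_pad = (4 - len(demo_bytes) % 4) % 4                       # 2 padding bytes -> 20 total
demo_bytes.extend(b'\x00' * demo_pad)
demo_packed = np.insert(np.frombuffer(demo_bytes, dtype=np.int32), 0, demo_pad)
# Decoding reverses the steps: drop the leading pad count, reinterpret as bytes, strip the padding.
demo_decoded = demo_packed[1:].tobytes()
demo_decoded = demo_decoded[:len(demo_decoded) - demo_packed[0]] if demo_packed[0] else demo_decoded
assert demo_decoded == b'\x89PNG\r\n\x1a\n' + b'\x00' * 10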
def get_args():
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--input', type=str, required=True,
help='Path to input JSON')
group.add_argument('--input-image', type=str, required=True,
help='Path to input image folder')
group.add_argument('--pad-length', type=int, required=True,
help='Pad length of preprocessed text')
group.add_argument('--split-sentences', action='store_true',
help='Split documents into sentences.')
group.add_argument('--keep-newlines', action='store_true',
help='Keep newlines between sentences when splitting.')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')
group.add_argument('--lang', type=str, default='english',
help='Language to use for NLTK-powered sentence splitting.')
group.add_argument('--tokenizer-model', type=str, default=None,
                       help='SentencePiece tokenizer model.')
group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, default=1,
help='Number of worker processes to launch')
group.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates')
args = parser.parse_args()
args.keep_empty = False
# some default/dummy values for the tokenizer
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
args.vocab_extra_ids = 0
return args
def main():
args = get_args()
startup_start = time.time()
encoder = Encoder(args)
tokenizer = build_tokenizer(args)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
fin = open(args.input + ".json", 'r', encoding='utf-8')
img_files = open(args.input_image)
encoded_docs = pool.imap(encoder.encode, zip(fin, img_files), 25)
print(f"Vocab size: {tokenizer.vocab_size}")
print(f"Output prefix: {args.output_prefix}")
output_bin_files = "{}_mmdata.bin".format(args.output_prefix)
output_idx_files = "{}_mmdata.idx".format(args.output_prefix)
builders = MMapIndexedDatasetBuilder(output_bin_files, dtype=np.int32, multimodal=True)
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (sentence, img_raw, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
builders.add_item(torch.IntTensor(sentence))
        builders.add_item(torch.from_numpy(img_raw), 1)  # second argument marks this item as the image modality
builders.end_document()
if i % args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {i} documents",
f"({i/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
builders.finalize(output_idx_files)
if __name__ == '__main__':
main()
......@@ -18,13 +18,11 @@ The following overview goes into more detail on the pipeline, code structure, us
<!-- ################ quick start ################ -->
# Quick start
Key files:

- `main.py` : Entry point for processing.
- `examples/preprocess_data.sh` : Example preprocessing launch (calls `main.py`).
- `examples/pretrain_data.sh` : Example pretraining launch (calls `pretrain_retro.py`).
Use `--retro-tasks` to move through the preprocessing pipeline.
......@@ -86,9 +84,8 @@ Multiple tasks can be specified by separating with commas (e.g., `--retro-tasks
Example scripts for setting arguments and launching Retro preprocessing. The key files here are:

- **`preprocess_data.sh`** : Example launch script for preprocessing Retro data.
- **`pretrain_model.sh`** : Example launch script for pretraining a Retro model.
### `tools/retro/db`
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import json
import numpy as np
import os
import torch
import types
from megatron.global_vars import set_global_variables, set_retro_args
from megatron.initialize import (
initialize_megatron,
_initialize_distributed,
_set_random_seed,
)
from tools.retro.db.utils import (
get_indexed_dataset_infos as get_db_indexed_dataset_infos,
get_merged_train_dataset as get_db_dataset,
)
from tools.retro.external_libs import h5py
from tools.retro.main import add_retro_args
from tools.retro.pretraining.retro_dataset import get_retro_datasets
from tools.retro.utils import get_args_path, get_bert_tokenizer, get_gpt_tokenizer
def shorten_str(s, n):
s = "\\n".join(s.splitlines())
return s if len(s) <= n else "%s ... %s" % (s[:n//2], s[-n//2:])
class retro:
args = None
##############################################
# initialize.
##############################################
@classmethod
def init_megatron(cls, workdir):
'''Custom initialization of Megatron.'''
# Load args.
args_path = get_args_path(workdir)
assert os.path.exists(args_path), "args.json not found in workdir."
with open(args_path) as f:
cls.args = types.SimpleNamespace(**json.load(f))
cls.args.retro_workdir = workdir # just in case workdir moved
cls.args.rank = 0 # override env
cls.args.world_size = 1 # override env
set_global_variables(cls.args)
set_retro_args(cls.args)
_initialize_distributed()
_set_random_seed(cls.args.seed, cls.args.data_parallel_random_init)
@classmethod
def init(cls, workdir):
'''Initialize Megatron, tokenizers, and datasets.'''
# Load args.
cls.init_megatron(workdir)
cls.tokenizers = types.SimpleNamespace(
gpt=get_gpt_tokenizer(),
bert=get_bert_tokenizer(),
)
# Load data.
cls.db_indexed_dataset_infos = get_db_indexed_dataset_infos()
pt_train_ds, pt_valid_ds, _ = get_retro_datasets()
cls.pt_datasets = types.SimpleNamespace(
train=pt_train_ds,
valid=pt_valid_ds,
)
# Print usage.
cls.print_usage()
##############################################
# utils.
##############################################
@classmethod
def gpt_to_text(cls, token_ids):
'''GPT tokens to text.'''
return cls.tokenizers.gpt.detokenize(token_ids)
@classmethod
def text_to_bert(cls, text):
'''Text to Bert tokens.'''
return cls.tokenizers.bert.tokenize(text)
##############################################
# chunk db.
##############################################
@classmethod
def get_db_num_indexed_datasets(cls):
'''Number of indexed datasets within blendable dataset.'''
return len(cls.db_indexed_dataset_infos)
@classmethod
def get_db_indexed_dataset_infos(cls):
'''Dataset infos, including number of training & sampled sets.'''
return [(info["ratio"], info["name"])
for info in cls.db_indexed_dataset_infos]
@classmethod
def get_db_dataset(cls):
return cls.pt_datasets.train.db_dataset
@classmethod
def get_db_num_chunks(cls):
'''Number of DB chunks.'''
return len(cls.get_db_dataset())
@classmethod
def get_db_chunk_gpt(cls, idx):
'''Get DB chunk as GPT token ids.'''
return cls.get_db_dataset()[idx]["text"].tolist()
@classmethod
def get_db_chunk_bert(cls, idx):
'''Get DB chunk as Bert token ids.'''
return cls.text_to_bert(cls.get_db_chunk_text(idx))
@classmethod
def get_db_chunk_text(cls, idx):
'''Get DB chunk as text.'''
return cls.gpt_to_text(cls.get_db_chunk_gpt(idx))
@classmethod
def get_db_chunk_and_continuation_text(cls, idx):
'''Get DB chunk along with continuation, as text.'''
        # Modulus used here to match the original implementation (i.e., the
        # last chunk's continuation wraps around to the first chunk).
return [
cls.get_db_chunk_text(idx),
cls.get_db_chunk_text((idx + 1) % len(cls.get_db_dataset())),
]
##############################################
# pretraining corpus.
##############################################
@classmethod
def get_pt_num_samples_and_chunks(cls, data_key):
'''Number of samples & chunks (e.g., 32*n_samples) in corpus.'''
assert hasattr(cls.pt_datasets, data_key), \
"pretraining set '%s' not found (choices: %s)." % (
data_key, ", ".join(vars(cls.pt_datasets).keys()))
chunk_dataset = getattr(cls.pt_datasets, data_key).chunk_dataset
return (
len(chunk_dataset.sample_dataset),
len(chunk_dataset),
)
@classmethod
def get_pt_num_samples(cls, data_key):
'''Number of pretraining samples.'''
return cls.get_pt_num_samples_and_chunks(data_key)[0]
@classmethod
def get_pt_num_chunks(cls, data_key):
'''Number of pretraining chunks (e.g., 32*n_samples).'''
return cls.get_pt_num_samples_and_chunks(data_key)[1]
@classmethod
def get_pt_sample(cls, data_key, idx):
return getattr(cls.pt_datasets, data_key)[idx]
##############################################
# usage.
##############################################
@classmethod
def print_usage(cls):
'''Print usage.'''
print()
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
print("examples ... [ *note*: 'db' = chunk db; 'pt' = pretraining corpus. ]")
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
print()
print("~~~~ indexed datasets ~~~~")
print("retro.get_db_num_indexed_datasets() : %s" %
cls.get_db_num_indexed_datasets())
print("retro.get_db_indexed_dataset_infos() :")
for i, (ratio,prefix) in enumerate(cls.get_db_indexed_dataset_infos()):
print(" %s(%f, %s)%s" % (
"[" if i == 0 else " ",
ratio,
prefix,
"]" if i == len(cls.db_indexed_dataset_infos) - 1 else ",",
))
print()
print("~~~~ counts ~~~~")
print("retro.get_db_num_chunks : %d." % cls.get_db_num_chunks())
print()
for sq_key in ("sample", "chunk"):
for data_key in ("train", "valid"): # test?
print("retro.get_pt_num_%ss('%s') : %d." % (
sq_key, data_key,
getattr(cls, f"get_pt_num_{sq_key}s")(data_key)))
print()
print("~~~~ tokens, text ~~~~")
print("retro.get_db_chunk_gpt(chunk_id) : %s" %
shorten_str(str(retro.get_db_chunk_gpt(0)), 50))
print("retro.get_db_chunk_bert(chunk_id) : %s" %
shorten_str(str(retro.get_db_chunk_bert(0)), 50))
print("retro.get_db_chunk_text(chunk_id) : %s" %
shorten_str(retro.get_db_chunk_text(0).strip(), 50))
print("retro.get_db_chunk_and_continuation_text(chunk_id) :")
for i, t in enumerate(retro.get_db_chunk_and_continuation_text(0)):
print(" %s'%s'%s" % (
"[" if i == 0 else " ",
shorten_str(t.strip().replace("\n", " "), 50),
"]" if i == 1 else ",",
))
sample = cls.get_pt_sample("train", 0)
print()
print("retro.get_pt_sample('train', sample_id) :")
print(" {")
for k, v in sample.items():
print(" '%s' : %s" % (k, shorten_str(str(v), 50)))
print(" }")
print()
print("(e.g., sample = retro.get_pt_sample(...))")
print()
print(" sample['text'].shape : %s" % str(sample["text"].shape))
print(" sample['neighbor_tokens'].shape : %s" % str(sample["neighbor_tokens"].shape))
print(" sample['text'] : %s" % shorten_str(str(sample["text"]), 50))
print(" sample['neighbor_tokens'][17][1] : %s" % shorten_str(str(sample["neighbor_tokens"][17][1]), 50))
print(" retro.gpt_to_text(sample['text']) : %s" % shorten_str(cls.gpt_to_text(sample["text"]), 50))
print(" retro.gpt_to_text(sample['neighbor_tokens']) : %s" % shorten_str(cls.gpt_to_text(sample["neighbor_tokens"][17][1]), 50))
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
from .cli import retro
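# Illustrative usage sketch (not part of the original sources): the class above is meant
# for interactive inspection of a finished Retro workdir. The workdir path below is a
# placeholder and must point at a directory containing args.json.
from tools.retro.cli import retro

retro.init('/path/to/retro/workdir')        # loads args.json, tokenizers, and datasets
print(retro.get_db_num_chunks())            # number of chunks in the retrieval database
print(retro.get_db_chunk_text(0))           # first DB chunk, detokenized to text
sample = retro.get_pt_sample('train', 0)    # pretraining sample with retrieved neighbors
print(sample['text'].shape, sample['neighbor_tokens'].shape)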