Commit 3aca1415 authored by liangjing's avatar liangjing
Browse files

Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04

See merge request !1
parents 0024a5c6 1005e9d3
Pipeline #1806 passed with stage
import torch import torch
from tests.test_utilities import Utils from tests.unit_tests.test_utilities import Utils
from megatron.core import ModelParallelConfig
import megatron.core.pipeline_parallel.schedules as schedule import megatron.core.pipeline_parallel.schedules as schedule
from pytest_mock import mocker from pytest_mock import mocker
import pytest import pytest
...@@ -20,8 +21,8 @@ def test_get_forward_backward_func(): ...@@ -20,8 +21,8 @@ def test_get_forward_backward_func():
def test_deallocate_output_tensor(): def test_deallocate_output_tensor():
out = torch.tensor([[1, 2, 3], [4, 5, 6]]) out = torch.tensor([[1, 2, 3], [4, 5, 6]])
schedule.deallocate_output_tensor(out) schedule.deallocate_output_tensor(out)
assert(out.nelement() == 1) assert(out.nelement() == 6)
"""
def test_forward_backward_func_without_pipeline_parallel(mocker): def test_forward_backward_func_without_pipeline_parallel(mocker):
from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.pipeline_parallel import get_forward_backward_func
...@@ -45,12 +46,18 @@ def test_forward_backward_func_without_pipeline_parallel(mocker): ...@@ -45,12 +46,18 @@ def test_forward_backward_func_without_pipeline_parallel(mocker):
assert(schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining) assert(schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining)
mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2) mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2)
config = ModelParallelConfig(
pipeline_model_parallel_size = 1
)
model.config = config
losses_reduced = forward_backward_func( losses_reduced = forward_backward_func(
forward_step_func=forward_step_func, forward_step_func=forward_step_func,
data_iterator=None, data_iterator=None,
model=[model], model=[model],
num_microbatches=4, num_microbatches=4,
seq_length=None,
micro_batch_size=None,
forward_only=False) forward_only=False)
loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}] loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
...@@ -83,6 +90,12 @@ def test_forward_backward_func_with_pipeline_parallel(mocker): ...@@ -83,6 +90,12 @@ def test_forward_backward_func_with_pipeline_parallel(mocker):
sequence_length = 512 sequence_length = 512
micro_batch_size = 8 micro_batch_size = 8
hidden_size = 256 hidden_size = 256
config = ModelParallelConfig(
pipeline_model_parallel_size = 4,
sequence_parallel = False
)
model.config = config
losses_reduced = forward_backward_func( losses_reduced = forward_backward_func(
forward_step_func=forward_step_func, forward_step_func=forward_step_func,
...@@ -90,9 +103,8 @@ def test_forward_backward_func_with_pipeline_parallel(mocker): ...@@ -90,9 +103,8 @@ def test_forward_backward_func_with_pipeline_parallel(mocker):
dtype=torch.float32, dtype=torch.float32,
model=[model], model=[model],
num_microbatches= micro_batch_size, num_microbatches= micro_batch_size,
tensor_shape=[sequence_length, micro_batch_size, hidden_size], seq_length=sequence_length,
decoder_seq_length=sequence_length, micro_batch_size=micro_batch_size,
sequence_parallel=False,
forward_only=True) forward_only=True)
loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}] loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
...@@ -101,7 +113,7 @@ def test_forward_backward_func_with_pipeline_parallel(mocker): ...@@ -101,7 +113,7 @@ def test_forward_backward_func_with_pipeline_parallel(mocker):
assert(i['loss_reduced'] == j['loss_reduced']) assert(i['loss_reduced'] == j['loss_reduced'])
Utils.destroy_model_parallel() Utils.destroy_model_parallel()
"""
def test_forward_backward_func_with_interleaving(mocker): def test_forward_backward_func_with_interleaving(mocker):
from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.enums import ModelType from megatron.core.enums import ModelType
...@@ -186,4 +198,4 @@ def test_forward_backward_func_with_interleaving(mocker): ...@@ -186,4 +198,4 @@ def test_forward_backward_func_with_interleaving(mocker):
assert(i['loss_reduced'] == j['loss_reduced']) assert(i['loss_reduced'] == j['loss_reduced'])
Utils.destroy_model_parallel() Utils.destroy_model_parallel()
""" """
\ No newline at end of file
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import pytest
import torch
from megatron.core.transformer.attention import SelfAttention
from tests.unit_tests.test_utilities import Utils
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
class TestParallelAttention:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
self.transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.parallel_attention = SelfAttention(self.transformer_config)
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_constructor(self):
assert isinstance(self.parallel_attention, SelfAttention)
assert self.parallel_attention.layer_number == 1
num_weights = sum([p.numel() for p in self.parallel_attention.parameters()])
assert num_weights == 648
def test_cpu_forward(self):
# we can't currently do this because the global memory buffer is on GPU
pass
def test_gpu_forward(self):
config = self.parallel_attention.config
sequence_length = 32
micro_batch_size = 2
self.parallel_attention.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size))
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
output, bias = self.parallel_attention(hidden_states, attention_mask)
assert config.recompute_granularity is None
assert output.shape[0] == sequence_length
assert output.shape[1] == micro_batch_size
assert output.shape[2] == config.hidden_size
assert bias.shape[0] == config.hidden_size
def test_checkpointed_gpu_forward(self):
transformer_config = self.transformer_config
transformer_config.recompute_granularity='selective'
checkpointed_parallel_attention = SelfAttention(transformer_config)
config = checkpointed_parallel_attention.config
sequence_length = 32
micro_batch_size = 2
checkpointed_parallel_attention.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones(
(sequence_length, micro_batch_size, checkpointed_parallel_attention.config.hidden_size)
)
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
output, bias = checkpointed_parallel_attention(hidden_states, attention_mask)
assert config.recompute_granularity == 'selective'
assert output.shape[0] == sequence_length
assert output.shape[1] == micro_batch_size
assert output.shape[2] == config.hidden_size
assert bias.shape[0] == config.hidden_size
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import pytest
import torch
from megatron.core.transformer.attention import CrossAttention
"""
@pytest.fixture
def core_attention(transformer_config):
return CrossAttention(transformer_config)
class TestCoreAttention:
def test_constructor(self, core_attention):
assert isinstance(core_attention, CrossAttention)
assert core_attention.layer_number == 1
num_weights = sum([p.numel() for p in core_attention.parameters()])
assert num_weights == 0
def test_cpu_forward(self, core_attention):
# we can't currently do this because the global memory buffer is on GPU
pass
def test_gpu_forward(self, core_attention):
# destroy_global_memory_buffer()
# _set_global_memory_buffer()
# model_parallel_cuda_manual_seed(123)
core_attention.cuda()
config = core_attention.config
sequence_length = 32
micro_batch_size = 2
# query_layer (float): [sequence_length, micro_batch_size, num_attention_heads, hidden_size / num_attention_heads]
query_layer = torch.ones(
(
sequence_length,
micro_batch_size,
config.num_attention_heads,
config.hidden_size // config.num_attention_heads,
)
).cuda()
key_layer = torch.ones_like(query_layer).cuda()
value_layer = torch.ones_like(query_layer).cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
context_layer = core_attention(
query_layer=query_layer, key_layer=key_layer, value_layer=value_layer, attention_mask=attention_mask
)
assert context_layer.shape[0] == sequence_length
assert context_layer.shape[1] == micro_batch_size
assert context_layer.shape[2] == config.hidden_size
assert context_layer.device.type == 'cuda'
assert context_layer.dtype == torch.float32
"""
\ No newline at end of file
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import pytest
import torch
from megatron.core.transformer.mlp import MLP
from tests.unit_tests.test_utilities import Utils
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
class TestParallelMLP:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.mlp = MLP(transformer_config)
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_constructor(self):
assert isinstance(self.mlp, MLP)
num_weights = sum([p.numel() for p in self.mlp.parameters()])
assert num_weights == 1236
"""
def test_cpu_forward(self, mlp):
# [sequence length, micro batch size, hidden size]
hidden_states = torch.ones((32, 2, mlp.config.hidden_size))
output, output_bias = mlp(hidden_states)
assert output.shape[0] == 32
assert output.shape[1] == 2
assert output.shape[2] == mlp.config.hidden_size
assert output_bias.shape[0] == mlp.config.hidden_size
assert output.dtype == torch.float32
"""
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_gpu_forward(self):
mlp = self.mlp
mlp.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((32, 2, mlp.config.hidden_size))
hidden_states = hidden_states.cuda()
output, output_bias = mlp(hidden_states)
assert output.shape[0] == 32
assert output.shape[1] == 2
assert output.shape[2] == mlp.config.hidden_size
assert output_bias.shape[0] == mlp.config.hidden_size
assert output.dtype == torch.float32
assert output.device.type == 'cuda'
assert output_bias.device.type == 'cuda'
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import pytest
import torch
from megatron.core.transformer.module import Float16Module, MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
DEVICE_CAPABILITY = None
if torch.cuda.is_available():
DEVICE_CAPABILITY = torch.cuda.get_device_capability()
class DummyModule(MegatronModule):
# def __init__(self, config: TransformerConfig, share_embeddings_and_output_weights=True):
def __init__(self, config: TransformerConfig):
super().__init__(config)
self.linear = torch.nn.modules.Linear(in_features=2, out_features=1)
def forward(self, x):
return self.linear(x)
class TestMegatronModule:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.megatron_module = DummyModule(config=transformer_config).cuda()
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_megatron_module(self):
megatron_module = self.megatron_module
assert megatron_module
assert megatron_module.config.hidden_size == 12
assert megatron_module.config.ffn_hidden_size == 48
assert megatron_module.linear.weight.dtype == torch.float32
x = torch.ones((2, 2)).cuda()
assert megatron_module(x).dtype == torch.float32
# TODO: test bad configs actually fail
# failed_module = megatron_module
# failed_module.fp16 = True
# failed_module.bf16 = True
class TestFloat16Module:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
self.transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.megatron_module = DummyModule(config=self.transformer_config).cuda()
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_fp16_module(self):
transformer_config = self.transformer_config
megatron_module = self.megatron_module
transformer_config.fp16 = True
fp16_module = Float16Module(config=transformer_config, module=megatron_module)
assert fp16_module
assert fp16_module.config.hidden_size == 12
assert fp16_module.config.ffn_hidden_size == 48
assert fp16_module.module.linear.weight.dtype == torch.float16
x = torch.ones((2, 2)).cuda()
# inputs are converted to fp16 then outputs are converted to fp32
assert fp16_module(x).dtype == torch.float32
pytest.mark.skipif(
not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, reason='bfloat16 is not supported on this device'
)
def test_bf16_module(self):
transformer_config = self.transformer_config
megatron_module = self.megatron_module
transformer_config.bf16 = True
bf16_module = Float16Module(config=transformer_config, module=megatron_module)
assert bf16_module
assert bf16_module.config.hidden_size == 12
assert bf16_module.config.ffn_hidden_size == 48
assert bf16_module.module.linear.weight.dtype == torch.bfloat16
x = torch.ones((2, 2)).cuda()
# inputs are converted to bf16 then outputs are converted to fp32
assert bf16_module(x).dtype == torch.float32
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
import pytest
import torch
from megatron.core import dist_checkpointing
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.transformer.transformer_block import TransformerBlock
from tests.unit_tests.test_utilities import Utils
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
class TestParallelTransformerBlock:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
self.transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.parallel_transformer_block = TransformerBlock(self.transformer_config)
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_constructor(self):
parallel_transformer_block = self.parallel_transformer_block
assert isinstance(parallel_transformer_block, TransformerBlock)
num_weights = sum([p.numel() for p in parallel_transformer_block.parameters()])
assert num_weights == 3792
assert parallel_transformer_block.num_layers_per_pipeline_rank == 2
assert len(parallel_transformer_block.layers) == 2
layer_0: TransformerLayer = parallel_transformer_block._get_layer(0)
assert layer_0.layer_number == 1
layer_1: TransformerLayer = parallel_transformer_block._get_layer(1)
assert layer_1.layer_number == 2
def test_gpu_forward(self):
parallel_transformer_block = self.parallel_transformer_block
config: TransformerConfig = parallel_transformer_block.config
sequence_length = 32
micro_batch_size = 2
parallel_transformer_block.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size))
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
hidden_states = parallel_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask)
assert hidden_states.shape[0] == sequence_length
assert hidden_states.shape[1] == micro_batch_size
assert hidden_states.shape[2] == config.hidden_size
def test_gpu_forward_full_checkpoint(self):
transformer_config = self.transformer_config
config = transformer_config
config.recompute_granularity = 'full'
config.recompute_method = 'block'
config.recompute_num_layers = config.num_layers
full_transformer_block = TransformerBlock(config)
assert full_transformer_block.config.recompute_granularity == 'full'
assert full_transformer_block.config.recompute_method == 'block'
sequence_length = 32
micro_batch_size = 2
full_transformer_block.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size))
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
hidden_states = full_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask)
assert hidden_states.shape[0] == sequence_length
assert hidden_states.shape[1] == micro_batch_size
assert hidden_states.shape[2] == config.hidden_size
def test_gpu_forward_selective_checkpoint(self):
transformer_config = self.transformer_config
config = transformer_config
config.recompute_granularity = 'selective'
selective_transformer_block = TransformerBlock(config)
assert selective_transformer_block.config.recompute_granularity == 'selective'
assert selective_transformer_block.checkpoint_core_attention
sequence_length = 32
micro_batch_size = 2
selective_transformer_block.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size))
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
hidden_states = selective_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask)
assert hidden_states.shape[0] == sequence_length
assert hidden_states.shape[1] == micro_batch_size
assert hidden_states.shape[2] == config.hidden_size
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import pytest
import torch
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from tests.unit_tests.test_utilities import Utils
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
class TestParallelTransformerLayer:
def setup_method(self, method):
Utils.initialize_model_parallel(1,1)
model_parallel_cuda_manual_seed(123)
transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
self.parallel_transformer_layer = TransformerLayer(transformer_config)
def teardown_method(self, method):
Utils.destroy_model_parallel()
def test_constructor(self):
parallel_transformer_layer = self.parallel_transformer_layer
assert isinstance(parallel_transformer_layer, TransformerLayer)
assert parallel_transformer_layer.layer_number == 1
num_weights = sum([p.numel() for p in parallel_transformer_layer.parameters()])
assert num_weights == 1884
def test_gpu_forward(self):
parallel_transformer_layer = self.parallel_transformer_layer
config: TransformerConfig = parallel_transformer_layer.config
sequence_length = 32
micro_batch_size = 2
parallel_transformer_layer.cuda()
# [sequence length, batch size, hidden size]
hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size))
hidden_states = hidden_states.cuda()
attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
hidden_states = parallel_transformer_layer(hidden_states=hidden_states, attention_mask=attention_mask)
assert hidden_states.shape[0] == sequence_length
assert hidden_states.shape[1] == micro_batch_size
assert hidden_states.shape[2] == config.hidden_size
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
# for now we just format core
black ${SCRIPT_DIR}/../megatron/core
isort ${SCRIPT_DIR}/../megatron/core
...@@ -11,9 +11,10 @@ from tqdm import tqdm ...@@ -11,9 +11,10 @@ from tqdm import tqdm
from megatron import get_args, get_tokenizer, print_rank_0 from megatron import get_args, get_tokenizer, print_rank_0
from megatron import core from megatron import core
from megatron.arguments import core_transformer_config_from_args
from megatron.core.enums import ModelType from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.model import BertModel from megatron.model import BertModel
from megatron.schedules import get_forward_backward_func
from megatron.training import setup_model_and_optimizer from megatron.training import setup_model_and_optimizer
from .dataset import BertEmbeddingDataset from .dataset import BertEmbeddingDataset
...@@ -28,8 +29,10 @@ def model_provider(pre_process=True, post_process=True): ...@@ -28,8 +29,10 @@ def model_provider(pre_process=True, post_process=True):
print_rank_0(" > build Bert model.") print_rank_0(" > build Bert model.")
args = get_args() args = get_args()
config = core_transformer_config_from_args(args)
num_tokentypes = 2 if args.bert_binary_head else 0 num_tokentypes = 2 if args.bert_binary_head else 0
model = BertModel( model = BertModel(
config=config,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head, add_binary_head=args.bert_binary_head,
parallel_output=True, parallel_output=True,
......
...@@ -104,14 +104,14 @@ def get_missing_blocks(workdir, n_samples, block_size, ...@@ -104,14 +104,14 @@ def get_missing_blocks(workdir, n_samples, block_size,
try: try:
f = h5py.File(path, "r") f = h5py.File(path, "r")
except: except:
raise Exception("unable to open/validate '%s'." % path) # raise Exception("unable to open/validate '%s'." % path)
os.remove(path) os.remove(path)
continue continue
try: try:
validate(f) validate(f)
except: except:
raise Exception("delete block file.") # raise Exception("delete block file '%s'." % path)
os.remove(path) os.remove(path)
finally: finally:
f.close() f.close()
...@@ -156,53 +156,38 @@ def get_missing_blocks_by_rank(workdir, n_samples, block_size, ...@@ -156,53 +156,38 @@ def get_missing_blocks_by_rank(workdir, n_samples, block_size,
return len(missing_blocks), rank_missing_blocks return len(missing_blocks), rank_missing_blocks
class IdPathMap: class BlockPathMap:
'''Maps indexes to the containing block path. '''Map an index to its containing block path.
This class optimizing the mapping of a large number of indexes to the The common use for this class is to have a directory of files containing
path of its containing block. For example, with block_size 1M, this class blocks of processed data, of uniform block size (e.g., 100k samples per
stores 1/1M as many (long) path strings, saving memory. file). Each file must follow a naming convention of 'startIdx-endIdx.[ext]',
where 'endIdx' minus 'startIdx' must equal the block size, with the possible
exception of the final block. Given an input index, this class maps the
index to the containing block file.
''' '''
def __init__(self, paths): @classmethod
self.paths = paths def from_dir(cls, _dir, block_size, ext="hdf5"):
self.path_index_map = {p:i for i,p in enumerate(paths)} '''Get list of block files, and create map.'''
self.id_index_map = {} assert os.path.isdir(_dir), f"directory not found, '{_dir}'."
return cls(sorted(glob.glob(_dir + f"/*.{ext}")), block_size)
def __init__(self, block_paths, block_size):
self.max_idx = 0
self.block_path_map = {}
for block_path in block_paths:
name = os.path.splitext(os.path.basename(block_path))[0]
start_idx, end_idx = [ int(i) for i in name.split("-") ]
self.block_path_map[start_idx] = block_path
self.max_idx = max(self.max_idx, end_idx)
self.block_size = block_size
def __str__(self): def __str__(self):
return "%d paths; %d ids" % (len(self.paths), len(self.id_index_map)) return "%d paths" % len(self.block_path_map)
def add(self, id, path):
'''Map index to a path.'''
self.id_index_map[id] = self.path_index_map[path]
def __contains__(self, idx):
'''Index added to this object?'''
return idx in self.id_index_map
def __getitem__(self, idx): def __getitem__(self, idx):
'''Get path from index.''' '''Get block path from index.'''
return self.paths[self.id_index_map[idx]] block_start_idx = self.block_size * (idx // self.block_size)
block_path = self.block_path_map[block_start_idx]
return block_path
def path_to_range(path):
'''Parse start/end indexes from block path name (e.g., 00010-00011.hdf5 ->
(10, 11).'''
return tuple([
int(i) for i in os.path.splitext(
os.path.basename(path))[0].split("-")])
def get_index_path_map(_dir):
'''Map contained indexes to block file path (on disk).'''
paths = sorted(glob.glob(_dir + "/*.hdf5"))
# Build index-path map.
idx_path_map = IdPathMap(paths)
for path in paths:
start_idx, end_idx = path_to_range(path)
for idx in range(start_idx, end_idx):
idx_path_map.add(idx, path)
return idx_path_map
...@@ -55,7 +55,7 @@ def _load_checkpoint(queue, args): ...@@ -55,7 +55,7 @@ def _load_checkpoint(queue, args):
] ]
margs = parse_args() margs = parse_args()
margs = load_args_from_checkpoint(margs) margs, checkpoint_args = load_args_from_checkpoint(margs)
# Arguments do sanity checks on the world size, but we don't care, # Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes # so trick it into thinking we are plenty of processes
...@@ -63,12 +63,15 @@ def _load_checkpoint(queue, args): ...@@ -63,12 +63,15 @@ def _load_checkpoint(queue, args):
margs = validate_args(margs) margs = validate_args(margs)
def check_for_arg(arg_name): def check_for_arg(arg_name, default=None):
if getattr(margs, arg_name, None) is None: if getattr(margs, arg_name, None) is None:
print(f"Checkpoint does not specify the argument {arg_name}. Exiting.") if default is not None:
print(f"Arguments: {margs}") setattr(margs, arg_name, default)
queue.put("exit") else:
exit(1) print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
print(f"Arguments: {margs}")
queue.put("exit")
exit(1)
check_for_arg('tensor_model_parallel_size') check_for_arg('tensor_model_parallel_size')
check_for_arg('pipeline_model_parallel_size') check_for_arg('pipeline_model_parallel_size')
...@@ -77,10 +80,13 @@ def _load_checkpoint(queue, args): ...@@ -77,10 +80,13 @@ def _load_checkpoint(queue, args):
check_for_arg('seq_length') check_for_arg('seq_length')
check_for_arg('num_attention_heads') check_for_arg('num_attention_heads')
check_for_arg('max_position_embeddings') check_for_arg('max_position_embeddings')
check_for_arg('position_embedding_type')
check_for_arg('tokenizer_type') check_for_arg('tokenizer_type')
check_for_arg('iteration') check_for_arg('iteration')
check_for_arg('bert_binary_head') check_for_arg('bert_binary_head')
check_for_arg('disable_bias_linear', False)
check_for_arg('params_dtype') check_for_arg('params_dtype')
check_for_arg('swiglu', False)
# Determine how to make our models # Determine how to make our models
if args.model_type == 'GPT': if args.model_type == 'GPT':
...@@ -97,18 +103,38 @@ def _load_checkpoint(queue, args): ...@@ -97,18 +103,38 @@ def _load_checkpoint(queue, args):
consumed_train_samples = None consumed_train_samples = None
consumed_valid_samples = None consumed_valid_samples = None
def get_models(count, dtype, pre_process, post_process): def get_models(count, dtype):
nonlocal consumed_train_samples nonlocal consumed_train_samples
nonlocal consumed_valid_samples nonlocal consumed_valid_samples
models = [] model_array_len = margs.virtual_pipeline_model_parallel_size
if model_array_len is None:
model_array_len = 1
models = [[] for _ in range(model_array_len)]
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
for rank in range(count): for rank in range(count):
mpu.set_tensor_model_parallel_rank(rank) mpu.set_tensor_model_parallel_rank(rank)
model_ = [model_provider(pre_process, post_process).to(dtype)] if margs.virtual_pipeline_model_parallel_size is not None:
model_ = []
for i in range(margs.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider(
pre_process=pre_process,
post_process=post_process
).to(dtype)
model_.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model_rank = 0
model_ = [model_provider(pre_process, post_process).to(dtype)]
margs.consumed_train_samples = 0 margs.consumed_train_samples = 0
margs.consumed_valid_samples = 0 margs.consumed_valid_samples = 0
load_checkpoint(model_, None, None) load_checkpoint(model_, None, None)
assert(len(model_) == 1)
model_ = model_[0]
if consumed_train_samples is not None: if consumed_train_samples is not None:
assert(margs.consumed_train_samples == consumed_train_samples) assert(margs.consumed_train_samples == consumed_train_samples)
else: else:
...@@ -117,17 +143,14 @@ def _load_checkpoint(queue, args): ...@@ -117,17 +143,14 @@ def _load_checkpoint(queue, args):
assert(margs.consumed_valid_samples == consumed_valid_samples) assert(margs.consumed_valid_samples == consumed_valid_samples)
else: else:
consumed_valid_samples = margs.consumed_valid_samples consumed_valid_samples = margs.consumed_valid_samples
models.append(model_) for vp_rank in range(model_array_len):
models[vp_rank].append(model_[vp_rank])
return models return models
if margs.num_layers_per_virtual_pipeline_stage is not None: set_global_variables(margs, build_tokenizer=False)
print("Model with an interleaved pipeline schedule are not yet supported.")
queue.put("exit")
exit(1)
set_global_variables(margs)
mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size) mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size) mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size)
fused_kernels.load(margs) fused_kernels.load(margs)
# Get true (non-padded) vocab size # Get true (non-padded) vocab size
...@@ -146,6 +169,9 @@ def _load_checkpoint(queue, args): ...@@ -146,6 +169,9 @@ def _load_checkpoint(queue, args):
# short aliases # short aliases
tp_size = margs.tensor_model_parallel_size tp_size = margs.tensor_model_parallel_size
pp_size = margs.pipeline_model_parallel_size pp_size = margs.pipeline_model_parallel_size
vp_size = margs.virtual_pipeline_model_parallel_size
if vp_size is None:
vp_size = 1
# metadata # metadata
md = types.SimpleNamespace() md = types.SimpleNamespace()
...@@ -159,15 +185,20 @@ def _load_checkpoint(queue, args): ...@@ -159,15 +185,20 @@ def _load_checkpoint(queue, args):
md.iteration = margs.iteration md.iteration = margs.iteration
md.params_dtype = margs.params_dtype md.params_dtype = margs.params_dtype
md.bert_binary_head = margs.bert_binary_head md.bert_binary_head = margs.bert_binary_head
md.output_layer = margs.untie_embeddings_and_output_weights
md.position_embedding_type = margs.position_embedding_type
md.linear_bias = margs.add_bias_linear
md.swiglu = margs.swiglu
md.previous_tensor_parallel_size = margs.tensor_model_parallel_size md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
md.true_vocab_size = true_vocab_size md.true_vocab_size = true_vocab_size
md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
md.checkpoint_args = checkpoint_args
# Get first pipe stage # Get first pipe stage
mpu.set_pipeline_model_parallel_rank(0) mpu.set_pipeline_model_parallel_rank(0)
post_process = pp_size == 1 all_models = [get_models(tp_size, md.params_dtype)]
models = get_models(tp_size, md.params_dtype, True, post_process) models = all_models[0][0]
md.consumed_train_samples = consumed_train_samples md.consumed_train_samples = consumed_train_samples
md.consumed_valid_samples = consumed_valid_samples md.consumed_valid_samples = consumed_valid_samples
...@@ -180,59 +211,83 @@ def _load_checkpoint(queue, args): ...@@ -180,59 +211,83 @@ def _load_checkpoint(queue, args):
# Send embeddings # Send embeddings
message = { message = {
"position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
"word embeddings": torch.cat( "word embeddings": torch.cat(
[models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)], [models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
dim = 0) dim = 0)
} }
if md.position_embedding_type == 'learned_absolute':
message["position embeddings"] = models[0].language_model.embedding.position_embeddings.weight.data
else:
assert not hasattr(models[0].language_model.embedding, 'position_embeddings')
queue_put("embeddings", message) queue_put("embeddings", message)
total_layer_num = 0 total_layer_num = 0
for pp_rank in range(pp_size): for vp_rank in range(vp_size):
if pp_rank > 0: mpu.set_virtual_pipeline_model_parallel_rank(vp_rank)
mpu.set_pipeline_model_parallel_rank(pp_rank) for pp_rank in range(pp_size):
post_process = pp_rank == pp_size - 1 if pp_rank > 0:
models = get_models(tp_size, md.params_dtype, False, post_process) mpu.set_pipeline_model_parallel_rank(pp_rank)
for layer_num in range(len(models[0].language_model.encoder.layers)): if vp_rank == 0:
message = {} all_models.append(get_models(tp_size, md.params_dtype))
models = all_models[pp_rank][vp_rank]
# Get non-parallel tensors from tp_rank 0 for layer_num in range(len(models[0].language_model.encoder.layers)):
layer = models[0].language_model.encoder.layers[layer_num] message = {}
message["input layernorm weight"] = layer.input_layernorm.weight.data
message["input layernorm bias"] = layer.input_layernorm.bias.data # Get non-parallel tensors from tp_rank 0
message["dense bias"] = layer.self_attention.dense.bias.data layer = models[0].language_model.encoder.layers[layer_num]
message["post layernorm weight"] = layer.post_attention_layernorm.weight.data message["input layernorm weight"] = layer.input_layernorm.weight.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data message["input layernorm bias"] = layer.input_layernorm.bias.data
message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
# Grab all parallel tensors for this layer if md.linear_bias:
qkv_weight = [] message["dense bias"] = layer.self_attention.dense.bias.data
qkv_bias = [] message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
dense_weight = []
mlp_l0_weight = [] # Grab all parallel tensors for this layer
mlp_l0_bias = [] qkv_weight = []
mlp_l1_weight = [] qkv_bias = []
for tp_rank, model in enumerate(models): dense_weight = []
layer = model.language_model.encoder.layers[layer_num] mlp_l0_weight = []
qkv_weight.append(layer.self_attention.query_key_value.weight.data) mlp_l0_bias = []
qkv_bias.append(layer.self_attention.query_key_value.bias.data) mlp_l1_weight = []
dense_weight.append(layer.self_attention.dense.weight.data) for tp_rank, model in enumerate(models):
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data) layer = model.language_model.encoder.layers[layer_num]
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data) qkv_weight.append(layer.self_attention.query_key_value.weight.data)
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data) dense_weight.append(layer.self_attention.dense.weight.data)
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
# concat them mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
message["qkv weight"] = torch.cat(qkv_weight, dim=0) if md.linear_bias:
message["qkv bias"] = torch.cat(qkv_bias, dim=0) qkv_bias.append(layer.self_attention.query_key_value.bias.data)
message["dense weight"] = torch.cat(dense_weight, dim=1) mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0) # Handle gated linear units
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1) if md.swiglu:
# concat all the first halves ('W's) and all the second halves ('V's)
queue_put(f"transformer layer {total_layer_num}", message) for tp_rank in range(tp_size):
mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0)
total_layer_num = total_layer_num + 1 message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0)
message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0)
else:
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
# simple concat of the rest
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["dense weight"] = torch.cat(dense_weight, dim=1)
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
if md.linear_bias:
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
if md.swiglu:
for tp_rank in range(tp_size):
mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0)
message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias],dim=0)
message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias],dim=0)
else:
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
queue_put(f"transformer layer {total_layer_num}", message)
total_layer_num = total_layer_num + 1
# Send final layernorm from tp_rank 0 # Send final layernorm from tp_rank 0
message = { message = {
...@@ -241,6 +296,15 @@ def _load_checkpoint(queue, args): ...@@ -241,6 +296,15 @@ def _load_checkpoint(queue, args):
} }
queue_put("final layernorm", message) queue_put("final layernorm", message)
if md.output_layer:
message = {
"weight": torch.cat(
[models[tp_rank].language_model.output_layer.weight.data for tp_rank in range(tp_size)],
dim = 0)
}
queue_put("output layer", message)
# Send BERT lm head and binary head if it exists # Send BERT lm head and binary head if it exists
if md.model_type == 'BERT': if md.model_type == 'BERT':
message = { message = {
......
...@@ -96,6 +96,7 @@ def save_checkpoint(queue, args): ...@@ -96,6 +96,7 @@ def save_checkpoint(queue, args):
'--seq-length', str(md.seq_length), '--seq-length', str(md.seq_length),
'--num-attention-heads', str(md.num_attention_heads), '--num-attention-heads', str(md.num_attention_heads),
'--max-position-embeddings', str(md.max_position_embeddings), '--max-position-embeddings', str(md.max_position_embeddings),
'--position-embedding-type', str(md.position_embedding_type),
'--tokenizer-type', str(md.tokenizer_type), '--tokenizer-type', str(md.tokenizer_type),
'--tensor-model-parallel-size', str(args.target_tensor_parallel_size), '--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
'--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size), '--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
...@@ -121,12 +122,47 @@ def save_checkpoint(queue, args): ...@@ -121,12 +122,47 @@ def save_checkpoint(queue, args):
elif md.params_dtype == torch.bfloat16: elif md.params_dtype == torch.bfloat16:
sys.argv.append('--bf16') sys.argv.append('--bf16')
if md.output_layer:
sys.argv.append('--untie-embeddings-and-output-weights')
if not md.linear_bias:
sys.argv.append('--disable-bias-linear')
if md.model_type == 'BERT' and not md.bert_binary_head: if md.model_type == 'BERT' and not md.bert_binary_head:
sys.argv.append('--bert-no-binary-head') sys.argv.append('--bert-no-binary-head')
margs = parse_args() margs = parse_args()
if hasattr (md, 'checkpoint_args'):
# These are arguments that we are either changing, or cause problems for validation if they are set
# Note that some of these deal with T5 so will need to be changed if we support T5.
args_to_keep = ['tensor_model_parallel_size', 'pipeline_model_parallel_size', 'world_size', 'params_dtype',
'num_layers_per_virtual_pipeline_stage', 'virtual_pipeline_model_parallel_size',
'masked_softmax_fusion', 'bias_gelu_fusion', 'bias_dropout_fusion',
'sequence_parallel', 'async_tensor_model_parallel_allreduce',
'no_load_optim', 'no_load_rng', 'no_save_optim', 'no_save_rng',
'vocab_file', 'tokenizer_model',
'save_interval', 'save',
'perform_initialization', 'use_cpu_initialization',
'encoder_num_layers', 'encoder_seq_length',
'distribute_saved_activations',
'train_iters', 'lr_decay_iters', 'lr_warmup_iters', 'lr_warmup_fraction',
'start_weight_decay', 'end_weight_decay']
for arg, value in vars(md.checkpoint_args).items():
if arg in args_to_keep:
continue
if not hasattr(margs, arg):
print(f"Checkpoint had argument {arg} but new arguments does not have this.")
continue
if getattr(margs, arg) != value:
print(f"Overwriting default {arg} value {getattr(margs, arg)} with value from checkpoint {value}.")
setattr(margs, arg, value)
validate_args(margs) validate_args(margs)
set_global_variables(margs)
set_global_variables(margs, build_tokenizer=False)
# margs = megatron args # margs = megatron args
margs = get_args() margs = get_args()
...@@ -164,7 +200,9 @@ def save_checkpoint(queue, args): ...@@ -164,7 +200,9 @@ def save_checkpoint(queue, args):
#----------- #-----------
embeddings_msg = queue_get("embeddings") embeddings_msg = queue_get("embeddings")
pos_embed = embeddings_msg.pop("position embeddings") pos_embed = None
if md.position_embedding_type == 'learned_absolute':
pos_embed = embeddings_msg.pop("position embeddings")
orig_word_embed = embeddings_msg.pop("word embeddings") orig_word_embed = embeddings_msg.pop("word embeddings")
check_message(embeddings_msg) check_message(embeddings_msg)
...@@ -203,9 +241,11 @@ def save_checkpoint(queue, args): ...@@ -203,9 +241,11 @@ def save_checkpoint(queue, args):
post_process = args.target_pipeline_parallel_size == 1 post_process = args.target_pipeline_parallel_size == 1
models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process) models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
for tp_rank, model in enumerate(models): for tp_rank, model in enumerate(models):
print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank]) model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed) if pos_embed is not None:
model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
else:
assert not hasattr(model.language_model.embedding, "position_embeddings")
# Transformer layers # Transformer layers
#------------------- #-------------------
...@@ -223,34 +263,51 @@ def save_checkpoint(queue, args): ...@@ -223,34 +263,51 @@ def save_checkpoint(queue, args):
# duplicated tensors # duplicated tensors
input_layernorm_weight = msg.pop("input layernorm weight") input_layernorm_weight = msg.pop("input layernorm weight")
input_layernorm_bias = msg.pop("input layernorm bias") input_layernorm_bias = msg.pop("input layernorm bias")
dense_bias = msg.pop("dense bias")
post_layernorm_weight = msg.pop("post layernorm weight") post_layernorm_weight = msg.pop("post layernorm weight")
post_layernorm_bias = msg.pop("post layernorm bias") post_layernorm_bias = msg.pop("post layernorm bias")
mlp_l1_bias = msg.pop("mlp l1 bias") if md.linear_bias:
dense_bias = msg.pop("dense bias")
mlp_l1_bias = msg.pop("mlp l1 bias")
# Split up the parallel tensors # Split up the parallel tensors
qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0) qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1) dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1) mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)
# Special handling for swiglu
if md.swiglu:
mlp_l0_weight_W = torch.chunk(msg.pop("mlp l0 weight W"), args.target_tensor_parallel_size, dim=0)
mlp_l0_weight_V = torch.chunk(msg.pop("mlp l0 weight V"), args.target_tensor_parallel_size, dim=0)
mlp_l0_weight = [torch.cat(weights, dim=0) for weights in zip(mlp_l0_weight_W, mlp_l0_weight_V)]
else:
mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
if md.linear_bias:
qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
if md.swiglu:
mlp_l0_bias_W = torch.chunk(msg.pop("mlp l0 bias W"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias_V = torch.chunk(msg.pop("mlp l0 bias V"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias = [torch.cat(bias, dim=0) for bias in zip(mlp_l0_bias_W, mlp_l0_bias_V)]
else:
mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
# Save them to the model # Save them to the model
for tp_rank in range(args.target_tensor_parallel_size): for tp_rank in range(args.target_tensor_parallel_size):
l = models[tp_rank].language_model.encoder.layers[layer] l = models[tp_rank].language_model.encoder.layers[layer]
l.input_layernorm.weight.data.copy_(input_layernorm_weight) l.input_layernorm.weight.data.copy_(input_layernorm_weight)
l.input_layernorm.bias.data.copy_(input_layernorm_bias) l.input_layernorm.bias.data.copy_(input_layernorm_bias)
l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank]) l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank]) l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
l.self_attention.dense.bias.data.copy_(dense_bias)
l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight) l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias) l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank]) l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank]) l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias) if md.linear_bias:
l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
l.self_attention.dense.bias.data.copy_(dense_bias)
l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
total_layer_num = total_layer_num + 1 total_layer_num = total_layer_num + 1
check_message(msg) check_message(msg)
...@@ -262,13 +319,24 @@ def save_checkpoint(queue, args): ...@@ -262,13 +319,24 @@ def save_checkpoint(queue, args):
for tp_rank in range(args.target_tensor_parallel_size): for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight) models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias) models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
if pp_rank != 0: if pp_rank != 0 and not md.output_layer:
# Copy word embeddings to final pipeline rank # Copy word embeddings to final pipeline rank
models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank]) models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
del final_layernorm_weight del final_layernorm_weight
del final_layernorm_bias del final_layernorm_bias
check_message(msg) check_message(msg)
if md.output_layer:
msg = queue_get("output layer")
if not hasattr(models[0].language_model, 'output_layer'):
print("ERROR: got an output layer, but model does not have one")
exit(1)
output_layer_weight = torch.chunk(msg.pop("weight"), args.target_tensor_parallel_size, dim=0)
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.output_layer.weight.data.copy_(output_layer_weight[tp_rank])
del output_layer_weight
check_message(msg)
msg = queue_get() msg = queue_get()
if msg != "done" and msg["name"] == "pooler": if msg != "done" and msg["name"] == "pooler":
if not hasattr(models[0].language_model, 'pooler'): if not hasattr(models[0].language_model, 'pooler'):
......
...@@ -14,7 +14,7 @@ The following steps show how to prepare training dataset to train the mode. ...@@ -14,7 +14,7 @@ The following steps show how to prepare training dataset to train the mode.
1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ) 1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ)
2. Remove blacklisted URLs. 2. Remove blacklisted URLs.
``` ```
python blacklist_urls.py <path to the dowloaded deduplicated URLs> <filename for clean urls. e.g. clean_urls.txt> python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for clean urls. e.g. clean_urls.txt>
``` ```
3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py). 3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py).
...@@ -37,7 +37,7 @@ python group_duplicate_urls.py <possible duplicate urls file> <output file conta ...@@ -37,7 +37,7 @@ python group_duplicate_urls.py <possible duplicate urls file> <output file conta
``` ```
4. Remove similar documents that were detected in the last step. 4. Remove similar documents that were detected in the last step.
``` ```
python remove_group_duplicates.py <file containing simialr documents> <cleaned data file> <outputfile containing deduplicate data> python remove_group_duplicates.py <file containing similar documents> <cleaned data file> <outputfile containing deduplicate data>
``` ```
5. Shuffle the dataset. 5. Shuffle the dataset.
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# WARNING! This file contains a blacklist of known malicious sites and thus contains some NSFW language.
import glob import glob
...@@ -47,6 +49,7 @@ domain_blacklist = set([ ...@@ -47,6 +49,7 @@ domain_blacklist = set([
'google', 'google',
'gunprime', 'gunprime',
'gyazo', 'gyazo',
'horsefucker',
'hotdealstar', 'hotdealstar',
'imagefap', 'imagefap',
'imageshack', 'imageshack',
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Processing data for pretraining.""" """Processing large data for pretraining."""
import argparse import argparse
import math
import json import json
import multiprocessing
import os import os
import sys import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir))) os.path.pardir)))
import time import time
import gzip
import glob
import torch import torch
import numpy as np
import multiprocessing
try: try:
import nltk import nltk
nltk_available = True nltk_available = True
...@@ -39,6 +41,7 @@ class IdentitySplitter(object): ...@@ -39,6 +41,7 @@ class IdentitySplitter(object):
def tokenize(self, *text): def tokenize(self, *text):
return text return text
class Encoder(object): class Encoder(object):
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
...@@ -51,33 +54,129 @@ class Encoder(object): ...@@ -51,33 +54,129 @@ class Encoder(object):
print("NLTK is not available to split sentences.") print("NLTK is not available to split sentences.")
exit() exit()
library = "tokenizers/punkt/{}.pickle".format(self.args.lang) library = "tokenizers/punkt/{}.pickle".format(self.args.lang)
print("loading: " + library)
splitter = nltk.load(library) splitter = nltk.load(library)
if self.args.keep_newlines: if self.args.keep_newlines:
# this prevents punkt from eating newlines after sentences # this prevents punkt from eating newlines after sentences
Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer( Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
train_text=splitter._params, train_text = splitter._params,
lang_vars=CustomLanguageVars()) lang_vars = CustomLanguageVars())
else: else:
Encoder.splitter = splitter Encoder.splitter = splitter
else: else:
Encoder.splitter = IdentitySplitter() Encoder.splitter = IdentitySplitter()
def split(self, json_line):
data = json.loads(json_line)
output = {}
for key in self.args.json_keys:
text = data[key]
max_len = 1000000
tokens_list = [Encoder.splitter.tokenize(text[i:i+max_len]) for i in range(0, len(text), max_len)]
output[key] = [tokens for partial in tokens_list for tokens in partial]
return json.dumps(output), len(json_line)
def encode(self, json_line): def encode(self, json_line):
data = json.loads(json_line) data = json.loads(json_line)
ids = {} ids = {}
lens = {}
for key in self.args.json_keys: for key in self.args.json_keys:
text = data[key] text = data[key]
if isinstance(text, list):
sentences = text
else:
sentences = [text]
doc_ids = [] doc_ids = []
for sentence in Encoder.splitter.tokenize(text): sentence_lens = []
for sentence in sentences:
sentence_ids = Encoder.tokenizer.tokenize(sentence) sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0: if len(sentence_ids) > 0:
doc_ids.append(sentence_ids) doc_ids.extend(sentence_ids)
sentence_lens.append(len(sentence_ids))
if len(doc_ids) > 0 and self.args.append_eod: if len(doc_ids) > 0 and self.args.append_eod:
doc_ids[-1].append(Encoder.tokenizer.eod) doc_ids.append(Encoder.tokenizer.eod)
sentence_lens[-1] += 1
ids[key] = doc_ids ids[key] = doc_ids
return ids, len(json_line) lens[key] = sentence_lens
return ids, lens, len(json_line)
class Partition(object):
def __init__(self, args, workers):
self.args = args
self.workers = workers
def print_processing_stats(self, count, proc_start, total_bytes_processed):
if count % self.args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {count} documents",
f"({count/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
def split_sentences(self, file_name):
input_file_name, output_file_name = file_name
print("Opening", input_file_name)
fin = open(input_file_name, 'r', encoding='utf-8')
fout = open(output_file_name, 'w')
encoder = Encoder(self.args)
pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
split_docs = pool.imap(encoder.split, fin, 32)
proc_start = time.time()
total_bytes_processed = 0
for i, (doc, bytes_processed) in enumerate(split_docs, start=1):
total_bytes_processed += bytes_processed
fout.write(doc + "\n")
self.print_processing_stats(i, proc_start, total_bytes_processed)
fin.close()
fout.close()
def process_json_file(self, file_name):
input_file_name, output_prefix = file_name
print("Opening", input_file_name)
fin = open(input_file_name, 'r', encoding='utf-8')
startup_start = time.time()
encoder = Encoder(self.args)
tokenizer = build_tokenizer(self.args)
pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 32)
level = "document"
if self.args.split_sentences:
level = "sentence"
output_bin_files = {}
output_idx_files = {}
builders = {}
for key in self.args.json_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix,
key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix,
key, level)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=self.args.dataset_impl,
vocab_size=tokenizer.vocab_size)
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (doc, sentence_lens, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for key in doc.keys():
builders[key].add_doc(doc[key], sentence_lens[key])
self.print_processing_stats(i, proc_start, total_bytes_processed)
fin.close()
builders[key].finalize(output_idx_files[key])
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -94,20 +193,21 @@ def get_args(): ...@@ -94,20 +193,21 @@ def get_args():
group = parser.add_argument_group(title='tokenizer') group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True, group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase', choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'], 'GPT2BPETokenizer', 'SentencePieceTokenizer',
'GPTSentencePieceTokenizer', 'NullTokenizer'],
help='What type of tokenizer to use.') help='What type of tokenizer to use.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='YTTM tokenizer model.')
group.add_argument('--vocab-file', type=str, default=None, group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file') help='Path to the vocab file')
group.add_argument('--vocab-size', default=786,
help='size of vocab for use with NullTokenizer')
group.add_argument('--merge-file', type=str, default=None, group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).') help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true', group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.') help='Append an <eod> token to the end of a document.')
group.add_argument('--lang', type=str, default='english', group.add_argument('--lang', type=str, default='english',
help='Language to use for NLTK-powered sentence splitting.') help='Language to use for NLTK-powered sentence splitting.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='sentencepeice tokenizer model.')
group = parser.add_argument_group(title='output data') group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True, group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix') help='Path to binary output file without suffix')
...@@ -116,84 +216,187 @@ def get_args(): ...@@ -116,84 +216,187 @@ def get_args():
group = parser.add_argument_group(title='runtime') group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, required=True, group.add_argument('--workers', type=int, required=True,
help='Number of worker processes to launch') help=('Number of worker processes to launch.'
group.add_argument('--chunk-size', type=int, required=True, 'A good default for fast pre-processing '
help='Chunk size assigned to each worker process') 'is: (workers * partitions) = available CPU cores.'))
group.add_argument('--log-interval', type=int, default=100, group.add_argument('--partitions', type=int, default=1,
help='Number of file partitions')
group.add_argument('--log-interval', type=int, default=1000,
help='Interval between progress updates') help='Interval between progress updates')
group.add_argument('--keep-sequential-samples', action='store_true',
help='Ensure ordering of samples in .jsonl files is '
'preserved when using partitions>1.')
args = parser.parse_args() args = parser.parse_args()
args.keep_empty = False args.keep_empty = False
if args.tokenizer_type.lower().startswith('bert'): if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences:
if not args.split_sentences: print("Are you sure you don't want to split sentences?")
print("Bert tokenizer detected, are you sure you don't want to split sentences?")
# some default/dummy values for the tokenizer # some default/dummy values for the tokenizer
args.rank = 0 args.rank = 1
args.make_vocab_size_divisible_by = 128 args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1 args.tensor_model_parallel_size = 1
args.vocab_extra_ids = 0 args.vocab_extra_ids = 0
return args return args
def get_file_name(args, file_id):
file_name, extension = os.path.splitext(args.input)
input_file_name = file_name + "_" + str(file_id) + extension
sentence_split_file = file_name + "_ss_" + str(file_id) + extension
output_prefix = args.output_prefix + "_" + str(file_id)
file_names = {
'partition': input_file_name,
'sentence_split': sentence_split_file,
'output_prefix': output_prefix}
return file_names
def check_files_exist(in_ss_out_names, key, num_partitions):
for i in range(num_partitions):
if not os.path.exists(in_ss_out_names[i][key]):
return False
return True
def main(): def main():
args = get_args() args = get_args()
startup_start = time.time()
print("Opening", args.input) if args.split_sentences:
fin = open(args.input, 'r', encoding='utf-8') if nltk_available:
nltk.download("punkt", quiet=True)
else:
raise Exception(
"nltk library required for sentence splitting is not available.")
if nltk_available and args.split_sentences: in_ss_out_names = []
nltk.download("punkt", quiet=True) if args.partitions == 1:
file_name, extension = os.path.splitext(args.input)
sentence_split_file = file_name + "_ss" + extension
file_names = {
'partition': args.input,
'sentence_split': sentence_split_file,
'output_prefix': args.output_prefix}
in_ss_out_names.append(file_names)
else:
in_file_names = glob.glob(args.input)
encoder = Encoder(args) # Count total number of lines across .jsonl files
tokenizer = build_tokenizer(args) if args.keep_sequential_samples:
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) total_sample_count = 0
encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size) for filename in in_file_names:
#encoded_docs = map(encoder.encode, fin) with open(filename, "r") as fin:
for fc, _ in enumerate(fin):
pass
total_sample_count += (fc + 1)
partition_size = math.ceil(total_sample_count / args.partitions)
# create .jsonl parition files
for idx in range(args.partitions):
in_ss_out_name = get_file_name(args, idx)
in_ss_out_names.append(in_ss_out_name)
# check to see if paritions were already created
partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions)
# check to see if paritions with split sentences already created
split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)
if not partitions_present and not split_sentences_present:
# populate .jsonl partition files from parent files
partitioned_input_files = []
for idx in range(args.partitions):
partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w')
partitioned_input_files.append(partitioned_input_file)
index = 0
if args.keep_sequential_samples: line_count = 0
for in_file_name in in_file_names:
# support for gzip files
if in_file_name.endswith(".gz"):
fin = gzip.open(in_file_name, 'rt')
else:
fin = open(in_file_name, 'r', encoding='utf-8')
for line in fin:
partitioned_input_files[index].write(line)
if args.keep_sequential_samples:
line_count += 1
if line_count % partition_size == 0:
index += 1
else:
index = (index + 1)%args.partitions
fin.close()
for idx in range(args.partitions):
partitioned_input_files[idx].close()
assert args.workers % args.partitions == 0
partition = Partition(args, args.workers//args.partitions)
# check to see if paritions with split sentences already created
split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)
# split sentences in partition files
if args.split_sentences and not split_sentences_present:
processes = []
for name in in_ss_out_names:
p = multiprocessing.Process(target=partition.split_sentences,
args=((name['partition'], name['sentence_split']),))
p.start()
processes.append(p)
for p in processes:
p.join()
if args.partitions == 1:
return
# encode partition files in parallel
processes = []
input_key = 'sentence_split' if args.split_sentences else 'partition'
for name in in_ss_out_names:
p = multiprocessing.Process(target=partition.process_json_file,
args=((name[input_key], name['output_prefix']),))
p.start()
processes.append(p)
for p in processes:
p.join()
if args.partitions == 1:
return
# merge bin/idx partitions
level = "document" level = "document"
if args.split_sentences: if args.split_sentences:
level = "sentence" level = "sentence"
print(f"Vocab size: {tokenizer.vocab_size}")
print(f"Output prefix: {args.output_prefix}")
output_bin_files = {} output_bin_files = {}
output_idx_files = {} output_idx_files = {}
builders = {} builders = {}
tokenizer = build_tokenizer(args)
for key in args.json_keys: for key in args.json_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
key, level) key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix, output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
key, level) key, level)
builders[key] = indexed_dataset.make_builder(output_bin_files[key], builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=args.dataset_impl, impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size) vocab_size=tokenizer.vocab_size)
for name in in_ss_out_names:
startup_end = time.time() parition_output_prefix = name['output_prefix']
proc_start = time.time() full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix,
total_bytes_processed = 0 key, level)
print("Time to startup:", startup_end - startup_start) builders[key].merge_file_(full_partition_output_prefix)
for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for key, sentences in doc.items():
if len(sentences) == 0:
continue
for sentence in sentences:
builders[key].add_item(torch.IntTensor(sentence))
builders[key].end_document()
if i % args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {i} documents",
f"({i/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
print("Done! Now finalizing.")
for key in args.json_keys:
builders[key].finalize(output_idx_files[key]) builders[key].finalize(output_idx_files[key])
if __name__ == '__main__': if __name__ == '__main__':
main() main()
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Processing large data for pretraining."""
import argparse
import math
import json
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import time
import gzip
import glob
import torch
import numpy as np
import multiprocessing
try:
import nltk
nltk_available = True
except ImportError:
nltk_available = False
from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
_period_context_fmt = r"""
\S* # some word material
%(SentEndChars)s # a potential sentence ending
\s* # <-- THIS is what I changed
(?=(?P<after_tok>
%(NonWord)s # either other punctuation
|
(?P<next_tok>\S+) # <-- Normally you would have \s+ here
))"""
class IdentitySplitter(object):
def tokenize(self, *text):
return text
class Encoder(object):
def __init__(self, args):
self.args = args
def initializer(self):
# Use Encoder class as a container for global data
Encoder.tokenizer = build_tokenizer(self.args)
if self.args.split_sentences:
if not nltk_available:
print("NLTK is not available to split sentences.")
exit()
library = "tokenizers/punkt/{}.pickle".format(self.args.lang)
splitter = nltk.load(library)
if self.args.keep_newlines:
# this prevents punkt from eating newlines after sentences
Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
train_text = splitter._params,
lang_vars = CustomLanguageVars())
else:
Encoder.splitter = splitter
else:
Encoder.splitter = IdentitySplitter()
def split(self, json_line):
data = json.loads(json_line)
output = {}
for key in self.args.json_keys:
text = data[key]
max_len = 1000000
tokens_list = [Encoder.splitter.tokenize(text[i:i+max_len]) for i in range(0, len(text), max_len)]
output[key] = [tokens for partial in tokens_list for tokens in partial]
return json.dumps(output), len(json_line)
def encode(self, json_line):
data = json.loads(json_line)
ids = {}
lens = {}
for key in self.args.json_keys:
text = data[key]
if isinstance(text, list):
sentences = text
else:
sentences = [text]
doc_ids = []
sentence_lens = []
for sentence in sentences:
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.extend(sentence_ids)
sentence_lens.append(len(sentence_ids))
if len(doc_ids) > 0 and self.args.append_eod:
doc_ids.append(Encoder.tokenizer.eod)
ids[key] = doc_ids
lens[key] = sentence_lens
return ids, lens, len(json_line)
class Partition(object):
def __init__(self, args, workers):
self.args = args
self.workers = workers
def print_processing_stats(self, count, proc_start, total_bytes_processed):
if count % self.args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {count} documents",
f"({count/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
def split_sentences(self, file_name):
input_file_name, output_file_name = file_name
print("Opening", input_file_name)
fin = open(input_file_name, 'r', encoding='utf-8')
fout = open(output_file_name, 'w')
encoder = Encoder(self.args)
pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
split_docs = pool.imap(encoder.split, fin, 32)
proc_start = time.time()
total_bytes_processed = 0
for i, (doc, bytes_processed) in enumerate(split_docs, start=1):
total_bytes_processed += bytes_processed
fout.write(doc + "\n")
self.print_processing_stats(i, proc_start, total_bytes_processed)
fin.close()
fout.close()
def process_json_file(self, file_name):
input_file_name, output_prefix = file_name
print("Opening", input_file_name)
fin = open(input_file_name, 'r', encoding='utf-8')
startup_start = time.time()
encoder = Encoder(self.args)
tokenizer = build_tokenizer(self.args)
pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 32)
level = "document"
if self.args.split_sentences:
level = "sentence"
output_bin_files = {}
output_idx_files = {}
builders = {}
for key in self.args.json_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix,
key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix,
key, level)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=self.args.dataset_impl,
vocab_size=tokenizer.vocab_size)
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (doc, sentence_lens, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for key in doc.keys():
builders[key].add_doc(doc[key], sentence_lens[key])
self.print_processing_stats(i, proc_start, total_bytes_processed)
fin.close()
builders[key].finalize(output_idx_files[key])
def get_args():
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--input', type=str, required=True,
help='Path to input JSON')
group.add_argument('--json-keys', nargs='+', default=['text'],
help='space separate listed of keys to extract from json')
group.add_argument('--split-sentences', action='store_true',
help='Split documents into sentences.')
group.add_argument('--keep-newlines', action='store_true',
help='Keep newlines between sentences when splitting.')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='YTTM tokenizer model.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')
group.add_argument('--lang', type=str, default='english',
help='Language to use for NLTK-powered sentence splitting.')
group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
group.add_argument('--dataset-impl', type=str, default='mmap',
choices=['lazy', 'cached', 'mmap'])
group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, default=1,
help='Number of worker processes to launch')
group.add_argument('--partitions', type=int, default=1,
help='Number of file partitions')
group.add_argument('--log-interval', type=int, default=1000,
help='Interval between progress updates')
args = parser.parse_args()
args.keep_empty = False
if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences:
print("Are you sure you don't want to split sentences?")
# some default/dummy values for the tokenizer
args.rank = 1
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
args.vocab_extra_ids = 0
return args
def get_file_name(args, file_id):
file_name, extension = os.path.splitext(args.input)
input_file_name = file_name + "_" + str(file_id) + extension
sentence_split_file = file_name + "_ss_" + str(file_id) + extension
output_prefix = args.output_prefix + "_" + str(file_id)
file_names = {
'partition': input_file_name,
'sentence_split': sentence_split_file,
'output_prefix': output_prefix}
return file_names
def check_files_exist(in_ss_out_names, key, num_partitions):
for i in range(num_partitions):
if not os.path.exists(in_ss_out_names[i][key]):
return False
return True
def main():
args = get_args()
if args.split_sentences:
if nltk_available:
nltk.download("punkt", quiet=True)
else:
raise Exception(
"nltk library required for sentence splitting is not available.")
in_ss_out_names = []
if args.partitions == 1:
file_name, extension = os.path.splitext(args.input)
sentence_split_file = file_name + "_ss" + extension
file_names = {
'partition': args.input,
'sentence_split': sentence_split_file,
'output_prefix': args.output_prefix}
in_ss_out_names.append(file_names)
else:
in_file_names = glob.glob(args.input)
# create .jsonl parition files
for idx in range(args.partitions):
in_ss_out_name = get_file_name(args, idx)
in_ss_out_names.append(in_ss_out_name)
# check to see if paritions were already created
partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions)
# check to see if paritions with split sentences already created
split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)
if not partitions_present and not split_sentences_present:
# populate .jsonl partition files from parent files
partitioned_input_files = []
for idx in range(args.partitions):
partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w')
partitioned_input_files.append(partitioned_input_file)
index = 0
for in_file_name in in_file_names:
# support for gzip files
if in_file_name.endswith(".gz"):
fin = gzip.open(in_file_name, 'rt')
else:
fin = open(in_file_name, 'r', encoding='utf-8')
for line in fin:
partitioned_input_files[index].write(line)
index = (index + 1)%args.partitions
fin.close()
for idx in range(args.partitions):
partitioned_input_files[idx].close()
assert args.workers % args.partitions == 0
partition = Partition(args, args.workers//args.partitions)
# check to see if paritions with split sentences already created
split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)
# split sentences in partition files
if args.split_sentences and not split_sentences_present:
processes = []
for name in in_ss_out_names:
p = multiprocessing.Process(target=partition.split_sentences,
args=((name['partition'], name['sentence_split']),))
p.start()
processes.append(p)
for p in processes:
p.join()
if args.partitions == 1:
return
# encode partition files in parallel
processes = []
input_key = 'sentence_split' if args.split_sentences else 'partition'
for name in in_ss_out_names:
p = multiprocessing.Process(target=partition.process_json_file,
args=((name[input_key], name['output_prefix']),))
p.start()
processes.append(p)
for p in processes:
p.join()
# merge bin/idx partitions
level = "document"
if args.split_sentences:
level = "sentence"
output_bin_files = {}
output_idx_files = {}
builders = {}
tokenizer = build_tokenizer(args)
for key in args.json_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
key, level)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size)
for name in in_ss_out_names:
parition_output_prefix = name['output_prefix']
full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix,
key, level)
builders[key].merge_file_(full_partition_output_prefix)
builders[key].finalize(output_idx_files[key])
if __name__ == '__main__':
main()
# coding=utf-8
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Processing text modality data for MultiModal pretraining."""
import argparse
import json
import multiprocessing
import os
import sys
import numpy as np
from torchvision.transforms import ToTensor
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import time
import torch
try:
import nltk
nltk_available = True
except ImportError:
nltk_available = False
from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset
from megatron.data.indexed_dataset import MMapIndexedDatasetBuilder
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
_period_context_fmt = r"""
\S* # some word material
%(SentEndChars)s # a potential sentence ending
\s* # <-- THIS is what I changed
(?=(?P<after_tok>
%(NonWord)s # either other punctuation
|
(?P<next_tok>\S+) # <-- Normally you would have \s+ here
))"""
class IdentitySplitter(object):
def tokenize(self, *text):
return text
class Encoder(object):
def __init__(self, args):
self.args = args
def initializer(self):
# Use Encoder class as a container for global data
Encoder.tokenizer = build_tokenizer(self.args)
def encode(self, input_pair):
json_line, img_file = input_pair
data = json.loads(json_line)
key = "text"
text = data[key]
sentence_ids = Encoder.tokenizer.tokenize(text)
pad_len = self.args.pad_length
if len(sentence_ids) > 0 and self.args.append_eod:
sentence_ids = sentence_ids[:pad_len]
current_length = len(sentence_ids)
sentence_ids.extend([Encoder.tokenizer.eod for _ in range(max(0,pad_len-current_length))])
with open(img_file[:-1], "rb") as tf:
xs = bytearray(tf.read())
img_pad = (4 - len(xs) % 4) % 4
xs.extend([0 for _ in range(img_pad)])
img_raw = np.frombuffer(xs, dtype=np.int32)
img_raw = np.insert(img_raw, 0, img_pad)
return sentence_ids, img_raw, len(json_line)
def get_args():
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--input', type=str, required=True,
help='Path to input JSON')
group.add_argument('--input-image', type=str, required=True,
help='Path to input image folder')
group.add_argument('--pad-length', type=int, required=True,
help='Pad length of preprocessed text')
group.add_argument('--split-sentences', action='store_true',
help='Split documents into sentences.')
group.add_argument('--keep-newlines', action='store_true',
help='Keep newlines between sentences when splitting.')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')
group.add_argument('--lang', type=str, default='english',
help='Language to use for NLTK-powered sentence splitting.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='sentencepeice tokenizer model.')
group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, default=1,
help='Number of worker processes to launch')
group.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates')
args = parser.parse_args()
args.keep_empty = False
# some default/dummy values for the tokenizer
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
args.vocab_extra_ids = 0
return args
def main():
args = get_args()
startup_start = time.time()
encoder = Encoder(args)
tokenizer = build_tokenizer(args)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
fin = open(args.input + ".json", 'r', encoding='utf-8')
img_files = open(args.input_image)
encoded_docs = pool.imap(encoder.encode, zip(fin, img_files), 25)
print(f"Vocab size: {tokenizer.vocab_size}")
print(f"Output prefix: {args.output_prefix}")
output_bin_files = "{}_mmdata.bin".format(args.output_prefix)
output_idx_files = "{}_mmdata.idx".format(args.output_prefix)
builders = MMapIndexedDatasetBuilder(output_bin_files, dtype=np.int32, multimodal=True)
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (sentence, img_raw, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
builders.add_item(torch.IntTensor(sentence))
builders.add_item(torch.from_numpy(img_raw), 1)
builders.end_document()
if i % args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {i} documents",
f"({i/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
builders.finalize(output_idx_files)
if __name__ == '__main__':
main()
...@@ -18,13 +18,11 @@ The following overview goes into more detail on the pipeline, code structure, us ...@@ -18,13 +18,11 @@ The following overview goes into more detail on the pipeline, code structure, us
<!-- ################ quick start ################ --> <!-- ################ quick start ################ -->
# Quick start # Quick start
See `examples/get_preprocess_cmd.sh` for example arguments.
Key files: Key files:
- `main.py` : Entry point. - `main.py` : Entry point for processing.
- `examples/get_preprocess_cmd.sh` : Build preprocessing command (for `main.py`). - `examples/preprocess_data.sh` : Example preprocessing launch (calls `main.py`).
- `examples/preprocess_data.sh` : Run preprocessing (calls `get_preprocess_cmd.sh`, `main.py`). - `examples/pretrain_data.sh` : Example pretraining launch (calls `pretrain_retro.py`).
Use `--retro-tasks` to move through the preprocessing pipeline. Use `--retro-tasks` to move through the preprocessing pipeline.
...@@ -86,9 +84,8 @@ Multiple tasks can be specified by separating with commas (e.g., `--retro-tasks ...@@ -86,9 +84,8 @@ Multiple tasks can be specified by separating with commas (e.g., `--retro-tasks
Example scripts for setting arguments and launch Retro preprocessing. The key files here are: Example scripts for setting arguments and launch Retro preprocessing. The key files here are:
- **`get_preprocess_cmd.sh`** : Sets up arguments and command for preprocessing. **Important note**: this script assumes a few environment variables are already set before it is called. Please see the `Environment vars.` section at the top of this file. Generally, environment variables must be set to determine the location of Retro workdirs, input datasets, and GPT and Bert model information. - **`preprocess_data.sh`** : Example launch script for preprocessing retro data.
- **`preprocess_data.sh`** : Calls `get_preprocess_cmd.sh` to get arguments, and then calls `main.py` to launch preprocessing. - **`pretrain_model.sh`** : Example launch script for pretraining a retro model.
- **`pretrain_model.sh`** : Example script for pretraining on Wikipedia data, after preprocessing is complete.
### `tools/retro/db` ### `tools/retro/db`
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import json from .cli import retro
import numpy as np
import os
import torch
import types
from megatron.global_vars import set_global_variables, set_retro_args
from megatron.initialize import (
initialize_megatron,
_initialize_distributed,
_set_random_seed,
)
from tools.retro.db.utils import (
get_indexed_dataset_infos as get_db_indexed_dataset_infos,
get_merged_train_dataset as get_db_dataset,
)
from tools.retro.external_libs import h5py
from tools.retro.main import add_retro_args
from tools.retro.pretraining.retro_dataset import get_retro_datasets
from tools.retro.utils import get_args_path, get_bert_tokenizer, get_gpt_tokenizer
def shorten_str(s, n):
s = "\\n".join(s.splitlines())
return s if len(s) <= n else "%s ... %s" % (s[:n//2], s[-n//2:])
class retro:
args = None
##############################################
# initialize.
##############################################
@classmethod
def init_megatron(cls, workdir):
'''Custom initialization of Megatron.'''
# Load args.
args_path = get_args_path(workdir)
assert os.path.exists(args_path), "args.json not found in workdir."
with open(args_path) as f:
cls.args = types.SimpleNamespace(**json.load(f))
cls.args.retro_workdir = workdir # just in case workdir moved
cls.args.rank = 0 # override env
cls.args.world_size = 1 # override env
set_global_variables(cls.args)
set_retro_args(cls.args)
_initialize_distributed()
_set_random_seed(cls.args.seed, cls.args.data_parallel_random_init)
@classmethod
def init(cls, workdir):
'''Initialize Megatron, tokenizers, and datasets.'''
# Load args.
cls.init_megatron(workdir)
cls.tokenizers = types.SimpleNamespace(
gpt=get_gpt_tokenizer(),
bert=get_bert_tokenizer(),
)
# Load data.
cls.db_indexed_dataset_infos = get_db_indexed_dataset_infos()
pt_train_ds, pt_valid_ds, _ = get_retro_datasets()
cls.pt_datasets = types.SimpleNamespace(
train=pt_train_ds,
valid=pt_valid_ds,
)
# Print usage.
cls.print_usage()
##############################################
# utils.
##############################################
@classmethod
def gpt_to_text(cls, token_ids):
'''GPT tokens to text.'''
return cls.tokenizers.gpt.detokenize(token_ids)
@classmethod
def text_to_bert(cls, text):
'''Text to Bert tokens.'''
return cls.tokenizers.bert.tokenize(text)
##############################################
# chunk db.
##############################################
@classmethod
def get_db_num_indexed_datasets(cls):
'''Number of indexed datasets within blendable dataset.'''
return len(cls.db_indexed_dataset_infos)
@classmethod
def get_db_indexed_dataset_infos(cls):
'''Dataset infos, including number of training & sampled sets.'''
return [(info["ratio"], info["name"])
for info in cls.db_indexed_dataset_infos]
@classmethod
def get_db_dataset(cls):
return cls.pt_datasets.train.db_dataset
@classmethod
def get_db_num_chunks(cls):
'''Number of DB chunks.'''
return len(cls.get_db_dataset())
@classmethod
def get_db_chunk_gpt(cls, idx):
'''Get DB chunk as GPT token ids.'''
return cls.get_db_dataset()[idx]["text"].tolist()
@classmethod
def get_db_chunk_bert(cls, idx):
'''Get DB chunk as Bert token ids.'''
return cls.text_to_bert(cls.get_db_chunk_text(idx))
@classmethod
def get_db_chunk_text(cls, idx):
'''Get DB chunk as text.'''
return cls.gpt_to_text(cls.get_db_chunk_gpt(idx))
@classmethod
def get_db_chunk_and_continuation_text(cls, idx):
'''Get DB chunk along with continuation, as text.'''
# Modulus used here to match original implementation (i.e., last
# chunks continuation wraps around to first chunk).
return [
cls.get_db_chunk_text(idx),
cls.get_db_chunk_text((idx + 1) % len(cls.get_db_dataset())),
]
##############################################
# pretraining corpus.
##############################################
@classmethod
def get_pt_num_samples_and_chunks(cls, data_key):
'''Number of samples & chunks (e.g., 32*n_samples) in corpus.'''
assert hasattr(cls.pt_datasets, data_key), \
"pretraining set '%s' not found (choices: %s)." % (
data_key, ", ".join(vars(cls.pt_datasets).keys()))
chunk_dataset = getattr(cls.pt_datasets, data_key).chunk_dataset
return (
len(chunk_dataset.sample_dataset),
len(chunk_dataset),
)
@classmethod
def get_pt_num_samples(cls, data_key):
'''Number of pretraining samples.'''
return cls.get_pt_num_samples_and_chunks(data_key)[0]
@classmethod
def get_pt_num_chunks(cls, data_key):
'''Number of pretraining chunks (e.g., 32*n_samples).'''
return cls.get_pt_num_samples_and_chunks(data_key)[1]
@classmethod
def get_pt_sample(cls, data_key, idx):
return getattr(cls.pt_datasets, data_key)[idx]
##############################################
# usage.
##############################################
@classmethod
def print_usage(cls):
'''Print usage.'''
print()
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
print("examples ... [ *note*: 'db' = chunk db; 'pt' = pretraining corpus. ]")
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
print()
print("~~~~ indexed datasets ~~~~")
print("retro.get_db_num_indexed_datasets() : %s" %
cls.get_db_num_indexed_datasets())
print("retro.get_db_indexed_dataset_infos() :")
for i, (ratio,prefix) in enumerate(cls.get_db_indexed_dataset_infos()):
print(" %s(%f, %s)%s" % (
"[" if i == 0 else " ",
ratio,
prefix,
"]" if i == len(cls.db_indexed_dataset_infos) - 1 else ",",
))
print()
print("~~~~ counts ~~~~")
print("retro.get_db_num_chunks : %d." % cls.get_db_num_chunks())
print()
for sq_key in ("sample", "chunk"):
for data_key in ("train", "valid"): # test?
print("retro.get_pt_num_%ss('%s') : %d." % (
sq_key, data_key,
getattr(cls, f"get_pt_num_{sq_key}s")(data_key)))
print()
print("~~~~ tokens, text ~~~~")
print("retro.get_db_chunk_gpt(chunk_id) : %s" %
shorten_str(str(retro.get_db_chunk_gpt(0)), 50))
print("retro.get_db_chunk_bert(chunk_id) : %s" %
shorten_str(str(retro.get_db_chunk_bert(0)), 50))
print("retro.get_db_chunk_text(chunk_id) : %s" %
shorten_str(retro.get_db_chunk_text(0).strip(), 50))
print("retro.get_db_chunk_and_continuation_text(chunk_id) :")
for i, t in enumerate(retro.get_db_chunk_and_continuation_text(0)):
print(" %s'%s'%s" % (
"[" if i == 0 else " ",
shorten_str(t.strip().replace("\n", " "), 50),
"]" if i == 1 else ",",
))
sample = cls.get_pt_sample("train", 0)
print()
print("retro.get_pt_sample('train', sample_id) :")
print(" {")
for k, v in sample.items():
print(" '%s' : %s" % (k, shorten_str(str(v), 50)))
print(" }")
print()
print("(e.g., sample = retro.get_pt_sample(...))")
print()
print(" sample['text'].shape : %s" % str(sample["text"].shape))
print(" sample['neighbor_tokens'].shape : %s" % str(sample["neighbor_tokens"].shape))
print(" sample['text'] : %s" % shorten_str(str(sample["text"]), 50))
print(" sample['neighbor_tokens'][17][1] : %s" % shorten_str(str(sample["neighbor_tokens"][17][1]), 50))
print(" retro.gpt_to_text(sample['text']) : %s" % shorten_str(cls.gpt_to_text(sample["text"]), 50))
print(" retro.gpt_to_text(sample['neighbor_tokens']) : %s" % shorten_str(cls.gpt_to_text(sample["neighbor_tokens"][17][1]), 50))
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment