"docs/vscode:/vscode.git/clone" did not exist on "3ee75292edb483abe2253596b98894616d7fe37c"
Commit 5942af97 authored by Jared Casper

Alias core.parallel_state as mpu and use it throughout code. RIP mpu.

parent c2ea914f
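For orientation, a hedged before/after sketch of the call-site convention this commit settles on (module and function names are taken from the hunks below; the exact old imports vary per file):

# Before: call sites mixed the legacy megatron.mpu package and the interim `core` module:
#     from megatron import mpu
#     from megatron import core
#     rank = core.get_tensor_model_parallel_rank()
#     data_b = mpu.broadcast_data(keys, data, datatype)
#
# After this commit: everything lives under megatron.core, with `mpu` aliasing
# core.parallel_state and the tensor-parallel helpers under core.tensor_parallel.
from megatron.core import mpu, tensor_parallel

# rank = mpu.get_tensor_model_parallel_rank()                    # needs initialized model parallel groups
# data_b = tensor_parallel.broadcast_data(keys, data, datatype)  # tensor-parallel data broadcast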
@@ -9,8 +9,8 @@ import numpy as np
import torch
from megatron import (core,
update_num_microbatches)
from megatron import update_num_microbatches
from megatron.core import mpu, tensor_parallel
from .global_vars import get_args
from .utils import (unwrap_model,
print_rank_0)
@@ -79,11 +79,11 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
# Use both the tensor and pipeline MP rank.
if pipeline_parallel is None:
pipeline_parallel = (core.get_pipeline_model_parallel_world_size() > 1)
pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
if tensor_rank is None:
tensor_rank = core.get_tensor_model_parallel_rank()
tensor_rank = mpu.get_tensor_model_parallel_rank()
if pipeline_rank is None:
pipeline_rank = core.get_pipeline_model_parallel_rank()
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
# Use both the tensor and pipeline MP rank. If using the distributed
# optimizer, then the optimizer's path must additionally include the
@@ -98,7 +98,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
if use_distributed_optimizer:
model_name = os.path.join(common_path, "model_rng.pt")
optim_name = os.path.join(
common_path + "_%03d" % core.get_data_parallel_rank(),
common_path + "_%03d" % mpu.get_data_parallel_rank(),
"optim.pt")
else:
model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
@@ -185,18 +185,18 @@ def get_rng_state():
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state(),
'rng_tracker_states': core.tensor_parallel.get_cuda_rng_tracker().get_states()}
'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}
rng_state_list = None
if torch.distributed.is_initialized() and \
core.get_data_parallel_world_size() > 1 and \
mpu.get_data_parallel_world_size() > 1 and \
args.data_parallel_random_init:
rng_state_list = \
[None for i in range(core.get_data_parallel_world_size())]
[None for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather_object(
rng_state_list,
rng_state,
group=core.get_data_parallel_group())
group=mpu.get_data_parallel_group())
else:
rng_state_list = [rng_state]
@@ -223,7 +223,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
# Collect args, model, RNG.
model_state_dict = {}
if not torch.distributed.is_initialized() \
or core.get_data_parallel_rank() == 0:
or mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model.
model_state_dict['args'] = args
@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
else:
for i in range(len(model)):
core.set_virtual_pipeline_model_parallel_rank(i)
mpu.set_virtual_pipeline_model_parallel_rank(i)
model_state_dict['model%d' % i] = \
model[i].state_dict_for_save_checkpoint()
@@ -246,7 +246,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
optim_state_dict = {}
if not args.no_save_optim \
and (not torch.distributed.is_initialized()
or core.get_data_parallel_rank() == 0
or mpu.get_data_parallel_rank() == 0
or args.use_distributed_optimizer):
# Optimizer stuff.
@@ -548,7 +548,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
model[0].load_state_dict(model_state_dict['model'], strict=strict)
else:
for i in range(len(model)):
core.set_virtual_pipeline_model_parallel_rank(i)
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering if needed
@@ -580,7 +580,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# access rng_state for data parallel rank
if args.data_parallel_random_init:
rng_state = model_state_dict['rng_state'][core.get_data_parallel_rank()]
rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
else:
rng_state = model_state_dict['rng_state'][0]
random.setstate(rng_state['random_rng_state'])
@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array
if not rng_state['rng_tracker_states']:
raise KeyError
core.tensor_parallel.get_cuda_rng_tracker().set_states(
tensor_parallel.get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states'])
else: # backward compatibility
random.setstate(model_state_dict['random_rng_state'])
@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array
if not model_state_dict['rng_tracker_states']:
raise KeyError
core.tensor_parallel.get_cuda_rng_tracker().set_states(
tensor_parallel.get_cuda_rng_tracker().set_states(
model_state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. '
@@ -640,7 +640,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
args.use_distributed_optimizer,
release=False)
if core.get_data_parallel_rank() == 0:
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
@@ -656,7 +656,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
model[0].load_state_dict(ret_state_dict)
torch.distributed.barrier()
if core.get_data_parallel_rank() == 0:
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
return model
from .parallel_state import (
initialize_model_parallel,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_pipeline_model_parallel_world_size,
get_pipeline_model_parallel_rank,
get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank,
get_data_parallel_world_size,
get_data_parallel_rank,
get_global_memory_buffer,
get_num_layers,
)
from megatron.core import tensor_parallel
import megatron.core.parallel_state
import megatron.core.tensor_parallel
import megatron.core.utils
# Alias parallel_state as mpu, its legacy name
mpu = parallel_state
__all__ = [
"parallel_state",
"tensor_parallel",
"utils",
]
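A minimal sketch of what the alias above provides to downstream code (assuming megatron is importable; rank queries additionally require initialized process groups):

from megatron.core import mpu, parallel_state

assert mpu is parallel_state   # `mpu` is just the legacy name for core.parallel_state
# e.g. mpu.get_data_parallel_rank(), mpu.get_tensor_model_parallel_world_size(), ...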
@@ -47,10 +47,6 @@ _DATA_PARALLEL_GLOBAL_RANKS = None
# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None
def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
......
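With is_unitialized() removed, callers switch to the inverse check on model_parallel_is_initialized(); the OpenRetreivalDataStore and FaissMIPSIndex hunks further down follow this pattern, roughly:

# before:
#     if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
#         ...
# after:
#     if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
#         ...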
@@ -5,6 +5,7 @@ from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes,
param_is_not_tensor_parallel_duplicate,
@@ -23,10 +24,14 @@ from .mappings import (
from .random import (
checkpoint,
get_cuda_rng_tracker,
model_parallel_cuda_manual_seed
model_parallel_cuda_manual_seed,
)
from .utils import split_tensor_along_last_dim
from .utils import (
split_tensor_along_last_dim,
split_tensor_into_1d_equal_chunks,
gather_split_1d_tensor,
)
__all__ = [
# cross_entropy.py
......
@@ -17,6 +17,10 @@ from megatron.core.parallel_state import (
get_tensor_model_parallel_world_size,
)
from .utils import (
split_tensor_into_1d_equal_chunks,
gather_split_1d_tensor,
)
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
@@ -55,38 +59,6 @@ def _set_cuda_rng_state(new_state, device=-1):
_lazy_call(cb)
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
"""Break a tensor into equal 1D chunks."""
partition_size = torch.numel(tensor) // \
get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
numel_gathered = torch.numel(tensor) * \
get_tensor_model_parallel_world_size()
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor,
group=get_tensor_model_parallel_group())
return gathered
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
......
@@ -10,12 +10,16 @@ def split_tensor_along_last_dim(
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
""" Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
@@ -28,11 +32,64 @@ def split_tensor_along_last_dim(
return tensor_list
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
""" Break a tensor into equal 1D chunks across tensor parallel ranks.
Returns a Tensor or View with this rank's portion of the data.
Arguments:
tensor: The tensor to split
Keyword Arguments:
new_buffer (bool): If True, returns a new Tensor.
If False, returns a view into the existing Tensor.
Default is False
"""
partition_size = torch.numel(tensor) // \
get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
""" Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
model parallel ranks.
Returns a new Tensor with the gathered data.
Arguments:
tensor: A Tensor or view of this rank's portion of the data.
"""
numel_gathered = torch.numel(tensor) * \
get_tensor_model_parallel_world_size()
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor,
group=get_tensor_model_parallel_group())
return gathered
class VocabUtility:
"""Split the vocabulary into `world_size` chunks and return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [first, last)"""
""" Split the vocabulary into `world_size` chunks and return the first
and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [first, last)
"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(
......
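A hedged usage sketch for the two helpers documented above (it assumes torch.distributed and Megatron's model-parallel groups are already initialized, that every tensor-parallel rank holds the same input tensor, and that its element count is divisible by the tensor-parallel world size):

import torch
from megatron.core import tensor_parallel

full = torch.arange(16, dtype=torch.float, device=torch.cuda.current_device())
local = tensor_parallel.split_tensor_into_1d_equal_chunks(full, new_buffer=True)  # this rank's slice
rebuilt = tensor_parallel.gather_split_1d_tensor(local)  # all-gather the slices back in rank order
assert torch.equal(rebuilt, full.view(-1))                # round trip recovers the flat tensor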
@@ -21,35 +21,6 @@ def divide(numerator, denominator):
return numerator // denominator
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = (
torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
)
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = parallel_state.get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(
numel_gathered,
dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
torch.distributed._all_gather_base(
gathered,
tensor,
group=parallel_state.get_tensor_model_parallel_group()
)
return gathered
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
......
@@ -4,7 +4,8 @@ import time
import numpy as np
import torch
from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron import get_args, get_tokenizer, print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.data.dataset_utils import create_masked_lm_predictions, \
pad_and_convert_to_numpy
from megatron.data.data_samplers import MegatronPretrainingSampler
@@ -57,7 +58,7 @@ def get_ict_batch(data_iterator):
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
......
@@ -8,8 +8,6 @@ import numpy as np
import torch
from megatron import print_rank_0
from megatron import mpu
class BlendableDataset(torch.utils.data.Dataset):
......
@@ -8,7 +8,7 @@ import torch
import numpy as np
from torch.utils.data import Dataset
from megatron import get_args
from megatron import mpu
from megatron.core import mpu
def build_pretraining_data_loader(dataset, consumed_samples):
......
@@ -28,9 +28,9 @@ import torch
from megatron import (
get_args,
mpu,
print_rank_0
)
from megatron.core import mpu
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
......
@@ -8,7 +8,8 @@ import time
import numpy as np
import torch
from megatron import mpu, print_rank_0
from megatron import print_rank_0
from megatron.core import mpu
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from megatron.data.dataset_utils import get_train_valid_test_split_
......
@@ -9,7 +9,8 @@ import random
import torch
from torch.utils.data import Dataset
from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron import print_rank_0, get_args, get_tokenizer
from megatron.core import tensor_parallel
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_open_retrieval_wiki_dataset():
@@ -32,7 +33,7 @@ def get_open_retrieval_batch(data_iterator):
# Broadcast data.
data = None if data_iterator is None else next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
row_id = data_b['row_id'].long()
......
@@ -4,9 +4,10 @@ import time
import numpy as np
import torch
from megatron import mpu, print_rank_0
from megatron import print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron import get_args, get_tokenizer, print_rank_0, mpu
from megatron import get_args, get_tokenizer, print_rank_0
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
@@ -47,7 +48,7 @@ def get_ict_batch(data_iterator):
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
......
@@ -7,7 +7,7 @@ import numpy as np
import torch
from megatron import get_args
from megatron import mpu
from megatron.core import mpu
def detach(tensor):
@@ -50,10 +50,10 @@ class OpenRetreivalDataStore(object):
def load_from_file(self):
"""Populate members from instance saved to file"""
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Unpickling BlockData", flush=True)
state_dict = pickle.load(open(self.embedding_path, 'rb'))
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print(">> Finished unpickling BlockData\n", flush=True)
self.embed_data = state_dict['embed_data']
@@ -137,7 +137,7 @@ class FaissMIPSIndex(object):
except ImportError:
raise Exception("Error: Please install faiss to use FaissMIPSIndex")
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Building index", flush=True)
cpu_index = faiss.IndexFlatIP(self.embed_size)
@@ -149,12 +149,12 @@ class FaissMIPSIndex(object):
config.useFloat16 = True
gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
self.mips_index = faiss.IndexIDMap(gpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on GPU", flush=True)
else:
# CPU index supports IDs so wrap with IDMap
self.mips_index = faiss.IndexIDMap(cpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on CPU", flush=True)
# if we were constructed with a BlockData, then automatically load it
@@ -199,7 +199,7 @@ class FaissMIPSIndex(object):
self.mips_index.add_with_ids(embeds_arr, indices_arr)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
print(">>> Finished adding block data to index", flush=True)
def search_mips_index(self, query_embeds, top_k, reconstruct=True):
......
@@ -4,7 +4,7 @@ import torch
import torch.distributed as dist
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.core import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
......
@@ -14,13 +14,10 @@ from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron import core
from megatron.core import mpu, tensor_parallel
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
set_tensor_model_parallel_world_size)
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu
@@ -65,13 +62,14 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
args = get_args()
if args.lazy_mpu_init:
# TODO is this still a necessary option?
args.use_cpu_initialization=True
# delayed initialization of DDP-related stuff
# We only set basic DDP globals
set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
# We only set basic DDP globals
mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
# and return function for external DDP manager
# to call when it has DDP initialized
set_tensor_model_parallel_rank(args.rank)
mpu.set_tensor_model_parallel_rank(args.rank)
return finish_mpu_init
else:
# Megatron's MPU is the master. Complete initialization right away.
@@ -147,7 +145,7 @@ def _compile_dependencies():
def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
"""Initialize torch.distributed and core model parallel."""
args = get_args()
device_count = torch.cuda.device_count()
@@ -185,17 +183,14 @@ def _initialize_distributed():
print('model parallel is already initialized')
else:
mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank)
core.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank)
print(f'> initialized tensor model parallel with size '
f'{core.get_tensor_model_parallel_world_size()}')
print(f'> initialized pipeline model parallel with size '
f'{core.get_pipeline_model_parallel_world_size()}')
if args.rank == 0:
print(f'> initialized tensor model parallel with size '
f'{mpu.get_tensor_model_parallel_world_size()}')
print(f'> initialized pipeline model parallel with size '
f'{mpu.get_pipeline_model_parallel_world_size()}')
def _init_autoresume():
@@ -219,7 +214,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
core.tensor_parallel.model_parallel_cuda_manual_seed(seed)
tensor_parallel.model_parallel_cuda_manual_seed(seed)
else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
......
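For context, the initialization path after this change reads roughly as below; this is a sketch rather than the full initialize_megatron() flow, and it assumes torch.distributed is already initialized with a world size compatible with the chosen parallel sizes:

from megatron.core import mpu, tensor_parallel

# carve the world into tensor/pipeline/data parallel process groups
mpu.initialize_model_parallel(tensor_model_parallel_size=1,
                              pipeline_model_parallel_size=1)

# seed the per-rank CUDA RNG tracker used by tensor-parallel layers (dropout etc.)
tensor_parallel.model_parallel_cuda_manual_seed(1234)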
@@ -5,7 +5,7 @@
import torch
from megatron import get_args
from megatron import core
from megatron.core import tensor_parallel
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
@@ -61,7 +61,7 @@ class BertLMHead(MegatronModule):
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
@@ -110,9 +110,9 @@ def post_language_model_processing(lm_output, pooled_output,
# lm_logits : [s, b, h] and lm_labels: [s, b]
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = core.vocab_parallel_cross_entropy(lm_logits, lm_labels)
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = core.vocab_parallel_cross_entropy(lm_logits.float(),
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
# [s, b] => [b s]
lm_loss = lm_loss.transpose(0,1).contiguous()
......
@@ -2,11 +2,11 @@ import os
import torch
import sys
from megatron import get_args, print_rank_0
from megatron import get_args, print_rank_0, get_tokenizer
from megatron.core import mpu
from megatron.checkpointing import fix_query_key_value_ordering
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import get_checkpoint_name
from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
......
@@ -5,7 +5,6 @@
import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
......