Commit 5942af97 authored by Jared Casper

Alias core.parallel_state as mpu and use it throughout code. RIP mpu.

parent c2ea914f
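
The hunks below replace "from megatron import mpu" with imports from megatron.core. As a rough sketch of what "alias core.parallel_state as mpu" could mean in practice, megatron/core/__init__.py might re-export the module under its legacy name. That file is not shown in this diff; the contents below are an assumption for illustration, not the actual commit.

# Hypothetical megatron/core/__init__.py; not part of this diff.
from megatron.core import parallel_state
from megatron.core import tensor_parallel

# Keep the legacy name "mpu" (model parallel utilities) working so callers
# can write: from megatron.core import mpu
mpu = parallel_state
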
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import tensor_parallel
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import BertModel, ModelType
 from megatron.training import pretrain
@@ -46,7 +46,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
     # Unpack.
     tokens = data_b['text'].long()
......
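
For orientation, here is a minimal sketch of the get_batch pattern the hunk above edits: the rank that owns the data iterator reads a batch, and tensor_parallel.broadcast_data sends the named tensors to the rest of its tensor-parallel group. The key list, dtype, and return value are simplified assumptions based on the fragments shown, not the full file.

import torch
from megatron.core import tensor_parallel

def get_batch(data_iterator):
    # Names and dtype of the tensors to broadcast (simplified; assumed).
    keys = ['text']
    datatype = torch.int64

    # Only the source rank pulls from the iterator; broadcast_data then
    # replicates the selected tensors across the tensor-parallel group.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    tokens = data_b['text'].long()
    return tokens
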
@@ -8,8 +8,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
-from megatron import mpu
-from megatron import core
+from megatron.core import tensor_parallel
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, ModelType
 from megatron.training import pretrain
@@ -43,7 +42,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = core.tensor_parallel.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
     # Unpack.
     tokens_ = data_b['text'].long()
......
@@ -12,7 +12,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import mpu
 from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import ModelType
......
@@ -9,9 +9,9 @@ import torch
 from megatron import (
     get_args,
     get_timers,
-    mpu,
     print_rank_0
 )
+from megatron.core import tensor_parallel
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import T5Model, ModelType
 from megatron.training import pretrain
@@ -80,7 +80,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
     # Unpack.
     tokens_enc = data_b['text_enc'].long()
......
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0
+from megatron import get_args, get_timers, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model import ModelType
 from megatron.model.vision.classification import VitClassificationModel
......
@@ -6,7 +6,7 @@ import torch.nn as nn
 import numpy as np
 import torch.distributed as dist
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0
+from megatron import get_args, get_timers, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.dino import DINOPretrainModel
 from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
......
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0, print_rank_last
+from megatron import get_args, get_timers, print_rank_0, print_rank_last
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.inpainting import VitInpaintingModel
 from megatron.model.vision.inpainting import MitInpaintingModel
......
@@ -10,7 +10,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_last, is_last_rank
-from megatron import mpu
+from megatron.core import mpu
 from megatron.schedules import get_forward_backward_func
 from tasks.finetune_utils import build_data_loader
 from tasks.finetune_utils import process_batch
......
@@ -9,7 +9,7 @@ import torch
 from megatron import get_args, get_num_microbatches
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.model import ModelType
......
@@ -5,7 +5,6 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
 from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
......
@@ -6,10 +6,10 @@ import json
 import torch
 import requests
 from nltk import word_tokenize
-from megatron import mpu
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
+from megatron.core import mpu
 from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.checkpointing import load_checkpoint
......
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from torch.utils.data import DataLoader
 from megatron import get_args, print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.utils import average_losses_across_data_parallel_group
 from tasks.finetune_utils import build_data_loader
......
@@ -9,8 +9,8 @@ import math
 import torch
 import torch.nn.functional as F
-from megatron import get_args, get_timers, get_tokenizer
-from megatron import mpu, print_rank_0
+from megatron import get_args, get_timers, get_tokenizer, print_rank_0
+from megatron.core import mpu
 from megatron.indexer import IndexBuilder
 from megatron.model.biencoder_model import biencoder_model_provider
 from megatron.utils import average_losses_across_data_parallel_group
......
@@ -13,7 +13,7 @@ import torch
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset, BatchSampler
-from megatron import print_rank_0, get_args, get_tokenizer, mpu
+from megatron import print_rank_0, get_args, get_tokenizer
 from megatron.data.biencoder_dataset_utils import make_attention_mask
 def get_nq_dataset(qa_data, split):
......
@@ -5,7 +5,6 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
 from megatron.model.multiple_choice import MultipleChoice
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
......
@@ -9,7 +9,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_0, print_rank_last
-from megatron import mpu
+from megatron.core import mpu
 from megatron.schedules import get_forward_backward_func
 from tasks.vision.finetune_utils import build_data_loader
 from tasks.vision.finetune_utils import process_batch
......
@@ -7,7 +7,8 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu, utils
+from megatron import utils
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.training import evaluate_and_print_results
......
@@ -7,7 +7,8 @@ import torch
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers
-from megatron import mpu, print_rank_0, print_rank_last
+from megatron import print_rank_0, print_rank_last
+from megatron.core import mpu
 from tasks.vision.finetune_utils import finetune
 from tasks.vision.finetune_utils import build_data_loader
 from megatron.utils import average_losses_across_data_parallel_group
......
@@ -6,7 +6,8 @@ import torch
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers
-from megatron import mpu, print_rank_0, print_rank_last
+from megatron import print_rank_0, print_rank_last
+from megatron.core import mpu
 from tasks.vision.finetune_utils import finetune
 from tasks.vision.finetune_utils import build_data_loader
 from megatron.utils import average_losses_across_data_parallel_group
......
@@ -9,7 +9,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.model import GPTModel
 from megatron.training import get_model
@@ -93,7 +93,7 @@ def forward_step(batch, model, eval_metric):
     if mpu.is_pipeline_last_stage():
         # For loss, return the unreduced loss.
         if eval_metric == 'loss':
-            losses = mpu.vocab_parallel_cross_entropy(
+            losses = mpu.tensor_parallel.vocab_parallel_cross_entropy(
                 output.contiguous().float(), labels.contiguous())
             loss = torch.sum(
                 losses.view(-1) * loss_mask.contiguous().view(-1).float())
......
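
A short annotated sketch of what the evaluation hunk above computes, assuming the shapes Megatron-LM uses elsewhere: output holds logits sharded along the vocabulary dimension across the tensor-parallel group, and labels and loss_mask are per-token tensors. Unlike the hunk, the sketch reaches vocab_parallel_cross_entropy through a direct tensor_parallel import, and the function name eval_loss is made up for illustration.

import torch
from megatron.core import mpu, tensor_parallel

def eval_loss(output, labels, loss_mask):
    # Only the last pipeline stage holds the logits, so only it computes loss.
    if not mpu.is_pipeline_last_stage():
        return None
    # Per-token cross entropy over the vocabulary-sharded logits (unreduced).
    losses = tensor_parallel.vocab_parallel_cross_entropy(
        output.contiguous().float(), labels.contiguous())
    # Zero out padded positions and sum over all tokens.
    return torch.sum(losses.view(-1) * loss_mask.contiguous().view(-1).float())
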