Commit 5942af97 authored by Jared Casper

Alias core.parallel_state as mpu and use it throughout code. RIP mpu.

parent c2ea914f
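
This commit retires the legacy `megatron.mpu` package in favor of `megatron.core.parallel_state`, keeping the short name alive as an alias so call sites only need to change their import. As a rough sketch of what the alias amounts to (the `megatron/core/__init__.py` change itself is not among the hunks shown below, so treat this as illustrative):

```python
# megatron/core/__init__.py -- illustrative sketch; this file is not
# part of the hunks shown below.
from megatron.core import parallel_state
from megatron.core import tensor_parallel
from megatron.core import utils

# Alias parallel_state under its legacy name so existing call sites can
# write `from megatron.core import mpu` and keep the old mpu.* spelling.
mpu = parallel_state
```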
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import tensor_parallel
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import BertModel, ModelType
 from megatron.training import pretrain
@@ -46,7 +46,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
     # Unpack.
     tokens = data_b['text'].long()
...
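
The functional change above is only where `broadcast_data` lives: it moved from the old `mpu` package to `megatron.core.tensor_parallel`. It broadcasts the named tensors from tensor-parallel rank 0 to the rest of the tensor-parallel group, so only one rank per group has to read the data iterator. A hedged sketch of the call, mirroring the BERT `get_batch` above (the `keys` list and the `data_iterator` handling are assumptions drawn from the surrounding code, not part of this hunk):

```python
import torch
from megatron.core import tensor_parallel

# Keys and dtype mirror the BERT batch; every key must map to a tensor
# of the given dtype.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
datatype = torch.int64

# Only tensor-parallel rank 0 needs a real batch; the other ranks of the
# group pass None and receive the broadcast tensors.
data = next(data_iterator) if data_iterator is not None else None
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
tokens = data_b['text'].long()
```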
@@ -8,8 +8,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
-from megatron import mpu
-from megatron import core
+from megatron.core import tensor_parallel
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, ModelType
 from megatron.training import pretrain
@@ -43,7 +42,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = core.tensor_parallel.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
     # Unpack.
     tokens_ = data_b['text'].long()
...
@@ -12,7 +12,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import mpu
 from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import ModelType
...
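
Here the import changes but the call sites do not: `mpu` now resolves to `megatron.core.parallel_state`, which provides the same topology queries under the same names. A few representative calls (a sketch; process-group setup is elided):

```python
from megatron.core import mpu  # alias of megatron.core.parallel_state

# Topology queries used throughout the training and eval loops; they
# assume mpu.initialize_model_parallel() has already been called.
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
dp_size = mpu.get_data_parallel_world_size()
on_last_stage = mpu.is_pipeline_last_stage()
```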
@@ -9,9 +9,9 @@ import torch
 from megatron import (
     get_args,
     get_timers,
-    mpu,
     print_rank_0
 )
+from megatron.core import tensor_parallel
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import T5Model, ModelType
 from megatron.training import pretrain
@@ -80,7 +80,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
     # Unpack.
     tokens_enc = data_b['text_enc'].long()
...
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0
+from megatron import get_args, get_timers, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model import ModelType
 from megatron.model.vision.classification import VitClassificationModel
...
@@ -6,7 +6,7 @@ import torch.nn as nn
 import numpy as np
 import torch.distributed as dist
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0
+from megatron import get_args, get_timers, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.dino import DINOPretrainModel
 from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
...
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0, print_rank_last
+from megatron import get_args, get_timers, print_rank_0, print_rank_last
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.inpainting import VitInpaintingModel
 from megatron.model.vision.inpainting import MitInpaintingModel
...
@@ -10,7 +10,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_last, is_last_rank
-from megatron import mpu
+from megatron.core import mpu
 from megatron.schedules import get_forward_backward_func
 from tasks.finetune_utils import build_data_loader
 from tasks.finetune_utils import process_batch
...
@@ -9,7 +9,7 @@ import torch
 from megatron import get_args, get_num_microbatches
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.model import ModelType
...
@@ -5,7 +5,6 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
 from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
...
@@ -6,10 +6,10 @@ import json
 import torch
 import requests
 from nltk import word_tokenize
-from megatron import mpu
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
+from megatron.core import mpu
 from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.checkpointing import load_checkpoint
...
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from torch.utils.data import DataLoader
 from megatron import get_args, print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.utils import average_losses_across_data_parallel_group
 from tasks.finetune_utils import build_data_loader
...
@@ -9,8 +9,8 @@ import math
 import torch
 import torch.nn.functional as F
-from megatron import get_args, get_timers, get_tokenizer
-from megatron import mpu, print_rank_0
+from megatron import get_args, get_timers, get_tokenizer, print_rank_0
+from megatron.core import mpu
 from megatron.indexer import IndexBuilder
 from megatron.model.biencoder_model import biencoder_model_provider
 from megatron.utils import average_losses_across_data_parallel_group
...
@@ -13,7 +13,7 @@ import torch
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset, BatchSampler
-from megatron import print_rank_0, get_args, get_tokenizer, mpu
+from megatron import print_rank_0, get_args, get_tokenizer
 from megatron.data.biencoder_dataset_utils import make_attention_mask

 def get_nq_dataset(qa_data, split):
...
@@ -5,7 +5,6 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
 from megatron.model.multiple_choice import MultipleChoice
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
...
@@ -9,7 +9,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_0, print_rank_last
-from megatron import mpu
+from megatron.core import mpu
 from megatron.schedules import get_forward_backward_func
 from tasks.vision.finetune_utils import build_data_loader
 from tasks.vision.finetune_utils import process_batch
...
@@ -7,7 +7,8 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu, utils
+from megatron import utils
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.training import evaluate_and_print_results
...
@@ -7,7 +7,8 @@ import torch
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers
-from megatron import mpu, print_rank_0, print_rank_last
+from megatron import print_rank_0, print_rank_last
+from megatron.core import mpu
 from tasks.vision.finetune_utils import finetune
 from tasks.vision.finetune_utils import build_data_loader
 from megatron.utils import average_losses_across_data_parallel_group
...
@@ -6,7 +6,8 @@ import torch
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers
-from megatron import mpu, print_rank_0, print_rank_last
+from megatron import print_rank_0, print_rank_last
+from megatron.core import mpu
 from tasks.vision.finetune_utils import finetune
 from tasks.vision.finetune_utils import build_data_loader
 from megatron.utils import average_losses_across_data_parallel_group
...
@@ -9,7 +9,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.model import GPTModel
 from megatron.training import get_model
@@ -93,7 +93,7 @@ def forward_step(batch, model, eval_metric):
     if mpu.is_pipeline_last_stage():
         # For loss, return the unreduced loss.
         if eval_metric == 'loss':
-            losses = mpu.vocab_parallel_cross_entropy(
+            losses = mpu.tensor_parallel.vocab_parallel_cross_entropy(
                 output.contiguous().float(), labels.contiguous())
             loss = torch.sum(
                 losses.view(-1) * loss_mask.contiguous().view(-1).float())
...
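
`vocab_parallel_cross_entropy` computes the cross entropy when the vocabulary dimension of the logits is sharded across the tensor-parallel group, without gathering the full logits. A hedged sketch of the direct spelling used in the pretraining hunks above (whether the `mpu` alias also exposes `tensor_parallel`, as this hunk assumes, depends on what `megatron.core` re-exports):

```python
from megatron.core import tensor_parallel

# output: logits whose last dimension is the local vocab shard;
# labels: target indices over the full vocabulary, identical on every
# tensor-parallel rank. Returns unreduced per-token losses.
losses = tensor_parallel.vocab_parallel_cross_entropy(
    output.contiguous().float(), labels.contiguous())
```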