Commit 209f91c9 authored by Jared Casper

Bring mpu.data into megatron.core.

parent 2e6a46e4
megatron/core/tensor_parallel/__init__.py
 from .cross_entropy import vocab_parallel_cross_entropy
+from .data import broadcast_data

 __all__ = [
     # cross_entropy.py
     "vocab_parallel_cross_entropy",
+    # data.py
+    "broadcast_data",
 ]
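With data.py re-exported here, the function is reachable through the core package under both spellings. A quick sanity check, assuming megatron is importable and megatron/core/__init__.py pulls in tensor_parallel (which the pretrain_gpt.py hunk below relies on):

```python
# Assumes megatron is importable; the attribute path mirrors the call
# site in the pretrain_gpt.py hunk further down.
from megatron import core
from megatron.core.tensor_parallel import broadcast_data

assert core.tensor_parallel.broadcast_data is broadcast_data
```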
megatron/core/tensor_parallel/data.py
@@ -2,9 +2,11 @@
 import torch

-from .initialize import get_tensor_model_parallel_group
-from .initialize import get_tensor_model_parallel_rank
-from .initialize import get_tensor_model_parallel_src_rank
+from megatron.core.parallel_state import (
+    get_tensor_model_parallel_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_src_rank,
+)

 _MAX_DATA_DIM = 5
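For reference, broadcast_data packs a dictionary of tensors into one flat buffer on the tensor-model-parallel source rank, broadcasts that buffer, and unpacks it on the other ranks, so only one rank per group has to touch the data iterator. A minimal sketch of that pattern, assuming torch.distributed is initialized and CUDA is available; broadcast_data_sketch and its plain src/group arguments are stand-ins for the real implementation, which uses the parallel_state helpers imported above:

```python
# Minimal sketch of the broadcast pattern, not the real Megatron code.
# Assumes torch.distributed is initialized and `src` is a global rank.
import math

import torch
import torch.distributed as dist


def broadcast_data_sketch(keys, data, datatype, group=None, src=0):
    """Broadcast data[key] tensors from rank `src` to all ranks in `group`."""
    if dist.get_rank() == src:
        # Source rank packs the requested tensors into one flat buffer so a
        # single collective moves everything; NCCL needs CUDA tensors.
        sizes = [list(data[k].size()) for k in keys]
        flat = torch.cat(
            [data[k].contiguous().view(-1).to(datatype) for k in keys]
        ).cuda()
    else:
        sizes = None
    # Every rank needs the shapes before it can allocate a receive buffer.
    holder = [sizes]
    dist.broadcast_object_list(holder, src=src, group=group)
    sizes = holder[0]
    if dist.get_rank() != src:
        numel = sum(math.prod(s) for s in sizes)
        flat = torch.empty(numel, dtype=datatype, device="cuda")
    dist.broadcast(flat, src=src, group=group)
    # Unpack the flat buffer back into one tensor per key.
    output, offset = {}, 0
    for key, size in zip(keys, sizes):
        n = math.prod(size)
        output[key] = flat[offset:offset + n].view(size)
        offset += n
    return output
```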
megatron/mpu/__init__.py
@@ -2,9 +2,6 @@
 """Model parallel utility interface."""

-from .data import broadcast_data
-
 from .initialize import is_unitialized
 from .initialize import destroy_model_parallel
 from .initialize import get_data_parallel_group
pretrain_gpt.py
@@ -9,6 +9,7 @@ from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
+from megatron import core
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, ModelType
 from megatron.training import pretrain
@@ -42,7 +43,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = core.tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     tokens_ = data_b['text'].long()
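Call sites then go through the new core namespace instead of mpu. A hypothetical stripped-down get_batch, assuming a Megatron setup where only the tensor-parallel source rank receives a live data_iterator; the 'text' key and the broadcast call match the hunk above, while the int64 dtype and everything else are assumptions for illustration:

```python
# Hypothetical minimal caller using the new path; only the 'text' key and
# the broadcast call itself come from the diff above.
import torch

from megatron import core


def get_batch_sketch(data_iterator):
    keys = ['text']
    datatype = torch.int64  # token ids, matching the .long() below
    # Only the rank that owns the iterator reads it; the other ranks pass
    # None and receive the tensors from the broadcast.
    data = next(data_iterator) if data_iterator is not None else None
    data_b = core.tensor_parallel.broadcast_data(keys, data, datatype)
    tokens = data_b['text'].long()
    return tokens
```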