Commit 209f91c9 authored by Jared Casper

Bring mpu.data into megatron.core.

parent 2e6a46e4
megatron/core/tensor_parallel/__init__.py
 from .cross_entropy import vocab_parallel_cross_entropy
+from .data import broadcast_data

 __all__ = [
     # cross_entropy.py
     "vocab_parallel_cross_entropy",
+    # data.py
+    "broadcast_data",
 ]
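With this hunk, broadcast_data is re-exported at the package level, so callers can import it without reaching into the data module directly. A minimal sketch of the import surface this enables, assuming a Megatron-LM checkout with megatron.core importable (the alias and assert are illustrative, not part of the commit):

# Both spellings resolve to the same function once the re-export is in
# place; the package path comes from the diff above, the rest is a sketch.
from megatron.core.tensor_parallel import broadcast_data
from megatron.core.tensor_parallel.data import broadcast_data as _direct

assert broadcast_data is _direct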
megatron/core/tensor_parallel/data.py
@@ -2,9 +2,11 @@
 import torch

-from .initialize import get_tensor_model_parallel_group
-from .initialize import get_tensor_model_parallel_rank
-from .initialize import get_tensor_model_parallel_src_rank
+from megatron.core.parallel_state import (
+    get_tensor_model_parallel_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_src_rank,
+)

 _MAX_DATA_DIM = 5
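The moved data.py keeps its logic and only swaps its imports over to megatron.core.parallel_state. For orientation, here is a minimal, hypothetical sketch of the broadcast pattern it implements: the tensor-parallel source rank packs a dict of tensors into one flat buffer and broadcasts sizes, then payload, to its group. The name broadcast_tensor_dict and the single-dtype, 1-D simplification are assumptions; the real module also broadcasts tensor shapes (up to _MAX_DATA_DIM dimensions) so receivers can reshape.

import torch
import torch.distributed as dist

def broadcast_tensor_dict(keys, data, dtype, group, src_rank):
    """Hypothetical sketch: broadcast a dict of same-dtype tensors from
    src_rank to every rank in group (assumes an already-initialized
    process group, e.g. CPU tensors on a gloo backend)."""
    if dist.get_rank() == src_rank:
        # Pack every tensor into one flat buffer so a single broadcast
        # moves the whole payload.
        flat = torch.cat([data[k].contiguous().view(-1) for k in keys]).to(dtype)
        sizes = torch.tensor([data[k].numel() for k in keys], dtype=torch.long)
    else:
        sizes = torch.empty(len(keys), dtype=torch.long)
    # Stage 1: broadcast per-key element counts so receivers can size buffers.
    dist.broadcast(sizes, src_rank, group=group)
    if dist.get_rank() != src_rank:
        flat = torch.empty(int(sizes.sum().item()), dtype=dtype)
    # Stage 2: broadcast the flattened payload itself.
    dist.broadcast(flat, src_rank, group=group)
    # Unpack the flat buffer back into a {key: 1-D tensor} dict.
    out, offset = {}, 0
    for k, n in zip(keys, sizes.tolist()):
        out[k] = flat[offset:offset + n]
        offset += n
    return out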
megatron/mpu/__init__.py
@@ -2,9 +2,6 @@
 """Model parallel utility interface."""

-from .data import broadcast_data
-
 from .initialize import is_unitialized
 from .initialize import destroy_model_parallel
 from .initialize import get_data_parallel_group
pretrain_gpt.py
@@ -9,6 +9,7 @@ from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
+from megatron import core
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, ModelType
 from megatron.training import pretrain
@@ -42,7 +43,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = core.tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     tokens_ = data_b['text'].long()
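The call-site change above only swaps mpu.broadcast_data for core.tensor_parallel.broadcast_data; the surrounding logic is unchanged. A simplified, hypothetical sketch of that pattern (get_batch_sketch is not in the commit; the real get_batch also builds labels, masks, and position ids):

import torch
from megatron import core

def get_batch_sketch(data_iterator, keys=('text',), datatype=torch.int64):
    # Only the tensor-parallel source rank draws a batch from the iterator;
    # every other rank passes None and receives the broadcast copy.
    data = next(data_iterator) if data_iterator is not None else None
    data_b = core.tensor_parallel.broadcast_data(list(keys), data, datatype)
    return data_b['text'].long()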