Commit 2e6a46e4 authored by Jared Casper

Start Megatron-Core with vocab parallel cross entropy

Also merged in some changes from apex
parent 28ba253f
from .parallel_state import (
    initialize_model_parallel,
    get_tensor_model_parallel_world_size,
    get_pipeline_model_parallel_world_size,
    get_data_parallel_world_size,
)
from megatron.core import tensor_parallel
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Model and data parallel groups."""

import torch
from typing import Optional

# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None

@@ -52,10 +49,12 @@ def is_unitialized():
    return _DATA_PARALLEL_GROUP is None


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
) -> None:
    """
    Initialize model data parallel groups.

@@ -67,7 +66,6 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
    pipeline_model_parallel_split_rank: for models with both encoder and decoder,
        rank in pipeline with split point.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will

@@ -84,49 +82,44 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()

    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size "
            f"({pipeline_model_parallel_size})"
        )

    data_parallel_size: int = world_size // (tensor_model_parallel_size *
                                             pipeline_model_parallel_size)

    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
    num_data_parallel_groups: int = world_size // data_parallel_size

    if virtual_pipeline_model_parallel_size is not None:
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size

    if pipeline_model_parallel_split_rank is not None:
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank

    rank = torch.distributed.get_rank()

    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GLOBAL_RANKS
    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
@@ -135,8 +128,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,

    # Build the model-parallel groups.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
    for i in range(data_parallel_size):
        ranks = [data_parallel_group_ranks[i]
                 for data_parallel_group_ranks in all_data_parallel_group_ranks]
@@ -163,15 +155,13 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
        'pipeline model parallel group is already initialized'
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    assert _POSITION_EMBEDDING_GROUP is None, \
        'position embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
@@ -516,7 +506,6 @@ def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return torch.distributed.get_rank(group=get_data_parallel_group())


def destroy_model_parallel():
    """Set the groups to none."""
    global _MODEL_PARALLEL_GROUP
@@ -531,3 +520,15 @@ def destroy_model_parallel():
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
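To make the rank bookkeeping above concrete, here is a small illustrative sketch (not part of the commit; the helper name enumerate_groups and the standalone form are my own) that reproduces the data-parallel and pipeline-parallel group construction for the 16-GPU example in the docstring, with tensor_model_parallel_size = 2 and pipeline_model_parallel_size = 4, and no torch.distributed calls:

def enumerate_groups(world_size=16, tensor_model_parallel_size=2,
                     pipeline_model_parallel_size=4):
    assert world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) == 0
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

    # Data-parallel groups: same stride logic as the loop in initialize_model_parallel.
    data_parallel_groups = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            data_parallel_groups.append(
                list(range(start_rank + j, end_rank, tensor_model_parallel_size)))

    # Pipeline-model-parallel groups: ranks strided by the number of pipeline groups.
    pipeline_model_parallel_groups = [
        list(range(i, world_size, num_pipeline_model_parallel_groups))
        for i in range(num_pipeline_model_parallel_groups)
    ]
    return data_parallel_groups, pipeline_model_parallel_groups

dp, pp = enumerate_groups()
print(dp)  # [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]
print(pp)  # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]

The printed lists match the docstring's example: eight data-parallel groups of two ranks each and four pipeline groups of four ranks each.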
from .cross_entropy import vocab_parallel_cross_entropy

__all__ = [
    # cross_entropy.py
    "vocab_parallel_cross_entropy",
]
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch

from megatron.core.parallel_state import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size
)
from .utils import VocabUtility


class _VocabParallelCrossEntropy(torch.autograd.Function):

    @staticmethod
    def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):

        # Maximum value along vocab dimension across all GPUs.
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
@@ -62,8 +64,32 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Normalize and optionally smooth logits
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

        vocab_size = exp_logits.size(-1)
        if label_smoothing > 0:
            """
            We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth.
            = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
            = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
            = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
            = ((K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
            = (1 - (alpha * K) / (K - 1)) * y_gt + ((alpha * K) / (K - 1)) * \sum_{i} y_i / K
            From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
            """
            assert 1.0 > label_smoothing > 0.0
            smoothing = label_smoothing * vocab_size / (vocab_size - 1)

            # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs.
            log_probs = torch.log(exp_logits)
            mean_log_probs = log_probs.mean(dim=-1)
            loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs

        ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size

        # Store softmax, target-mask and masked-target for backward pass.
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss
@@ -89,9 +115,20 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
        # Finally elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        return grad_input, None, None


def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
    """
    Performs cross entropy loss when logits are split across tensor parallel ranks

    Arguments:
        vocab_parallel_logits: logits split across tensor parallel ranks
                               dimension is [sequence_length, batch_size, vocab_size/num_ranks]

        target: correct vocab ids of dimension [sequence_length, micro_batch_size]

        label_smoothing: smoothing factor, must be in range [0.0, 1.0)
                         default is no smoothing (=0.0)
    """
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
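The label-smoothing branch in forward() above spreads the mass alpha/(K-1) over the non-target entries, a slightly different convention from torch.nn.functional.cross_entropy, whose label_smoothing spreads alpha/K over all entries (target included). As a single-rank sanity check, here is a hedged sketch that replays the same arithmetic on unsplit logits and compares it to the built-in loss with the rescaled factor alpha * K / (K - 1); it does not exercise the tensor-parallel path, and the sizes and seed are purely illustrative:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, N = 11, 5                          # vocab size and number of tokens (illustrative)
logits = torch.randn(N, K)
target = torch.randint(0, K, (N,))
alpha = 0.1

# Same math as the forward pass above, on a single rank (no vocab split).
log_probs = F.log_softmax(logits, dim=-1)
nll = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
smoothing = alpha * K / (K - 1)
loss = (1.0 - smoothing) * nll - smoothing * log_probs.mean(dim=-1)

# PyTorch spreads its smoothing mass over all K classes (target included), so the
# equivalent built-in setting is the rescaled alpha * K / (K - 1).
reference = F.cross_entropy(logits, target, reduction='none',
                            label_smoothing=alpha * K / (K - 1))
print(torch.allclose(loss, reference, atol=1e-5))  # expected: True

In an actual tensor-parallel job the entry point is tensor_parallel.vocab_parallel_cross_entropy(logits, target, label_smoothing) after initialize_model_parallel(), with the logits split along the vocabulary dimension.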
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch
from typing import List, Sequence

from megatron.core.utils import divide


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
@@ -17,8 +18,11 @@ def divide(numerator, denominator):
    return numerator // denominator


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
@@ -39,19 +43,21 @@ def split_tensor_along_last_dim(tensor, num_partitions,


class VocabUtility:
    """Split the vocabulary into `world_size` chunks and return the
    first and last index of the vocabulary belonging to the `rank`
    partition: note that indices are in [first, last)."""

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int, rank, world_size: int
    ) -> Sequence[int]:
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size
        )
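As a quick illustration of the partitioning above (a sketch only: it assumes the class is importable as megatron.core.tensor_parallel.utils.VocabUtility, and the vocabulary and world size are made-up numbers), each rank receives a contiguous [first, last) slice of the global vocabulary:

from megatron.core.tensor_parallel.utils import VocabUtility

global_vocab_size, world_size = 50304, 4      # illustrative padded vocab and TP size
for rank in range(world_size):
    first, last = VocabUtility.vocab_range_from_global_vocab_size(
        global_vocab_size, rank, world_size)
    print(rank, first, last)
# rank 0 owns [0, 12576), rank 1 owns [12576, 25152), and so on;
# the divide() call asserts if the vocabulary does not split evenly.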
"""Utility functions used through Megatron core"""
import torch
from megatron.core import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = (
torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
)
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = parallel_state.get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(
numel_gathered,
dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
torch.distributed._all_gather_base(
gathered,
tensor,
group=parallel_state.get_tensor_model_parallel_group()
)
return gathered
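split_tensor_into_1d_equal_chunks and gather_split_1d_tensor are designed as a round trip across the tensor-model-parallel group: each rank keeps one contiguous 1D slice, and an all-gather rebuilds the flat tensor. A hedged sketch of that intended usage (it assumes torch.distributed and the model-parallel groups are already initialized, that the tensor's element count divides evenly by the tensor-parallel world size, and that the module is importable as megatron.core.utils; the roundtrip helper is illustrative, not part of the commit):

from megatron.core.utils import (
    split_tensor_into_1d_equal_chunks,
    gather_split_1d_tensor,
)

def roundtrip(full_tensor):
    # Keep only this rank's contiguous slice of the flattened tensor.
    local_chunk = split_tensor_into_1d_equal_chunks(full_tensor)
    # All-gather the slices from every tensor-parallel rank and restore the shape.
    rebuilt = gather_split_1d_tensor(local_chunk)
    return rebuilt.view_as(full_tensor)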
@@ -15,6 +15,7 @@ from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron import core
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
@@ -187,6 +188,14 @@ def _initialize_distributed():
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
                args.pipeline_model_parallel_split_rank)
            core.initialize_model_parallel(args.tensor_model_parallel_size,
                                           args.pipeline_model_parallel_size,
                                           args.virtual_pipeline_model_parallel_size,
                                           args.pipeline_model_parallel_split_rank)
            print(f'> initialized tensor model parallel with size '
                  f'{core.get_tensor_model_parallel_world_size()}')
            print(f'> initialized pipeline model parallel with size '
                  f'{core.get_pipeline_model_parallel_world_size()}')
def _init_autoresume():
@@ -5,7 +5,7 @@
import torch

from megatron import get_args
from megatron import core
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
@@ -110,10 +110,10 @@ def post_language_model_processing(lm_output, pooled_output,
        # lm_logits : [s, b, h] and lm_labels: [s, b]
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
            lm_loss = core.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            lm_loss = core.vocab_parallel_cross_entropy(lm_logits.float(),
                                                        lm_labels)
        # [s, b] => [b s]
        lm_loss = lm_loss.transpose(0,1).contiguous()
    return lm_loss, binary_logits
@@ -6,6 +6,7 @@ import torch
from megatron import get_args
from megatron import mpu
from megatron import core
from .module import MegatronModule
from .enums import AttnMaskType
@@ -33,9 +34,9 @@ def post_language_model_processing(lm_output, labels, logit_weights,
        labels = labels.transpose(0,1).contiguous()
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = core.tensor_parallel.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = core.tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)

        # [s b] => [b, s]
        loss = loss.transpose(0,1).contiguous()
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
"""Model parallel utility interface.""" """Model parallel utility interface."""
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data from .data import broadcast_data
......