Commit b886b7bb authored by Mohammad Shoeybi

created megatron package

parent 4947002d
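The diff below moves the repository's top-level modules (fp16, mpu, model, utils, learning_rates, data_utils) under a megatron package, so bare imports become package-qualified imports. A minimal sketch of the resulting import pattern follows; the directory listing is an assumption pieced together from the import statements in this commit, not the commit's actual file list.

# Hypothetical package layout implied by the new imports below
# (module names taken from the diff; the real repository may contain more):
# megatron/
#     __init__.py
#     fp16/
#     learning_rates.py
#     model/
#     mpu/
#     utils.py
#     data_utils/
#         tokenization_gpt2.py

# Before this commit:
# import mpu
# from fp16 import FP16_Optimizer

# After this commit:
from megatron import mpu
from megatron.fp16 import FP16_Optimizer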
@@ -22,9 +22,9 @@ import numpy as np
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from fp16 import FP16_Optimizer
-import mpu
-import model
+from megatron.fp16 import FP16_Optimizer
+from megatron import mpu
+from megatron import model
 def print_rank_0(message):
@@ -16,7 +16,7 @@
 import sys
 sys.path.append('..')
-from data_utils.tokenization_gpt2 import GPT2Tokenizer
+from megatron.data_utils.tokenization_gpt2 import GPT2Tokenizer
 class Tokenizer:
@@ -24,24 +24,24 @@ import torch.nn.functional as F
 from arguments import get_args
 from configure_data import configure_data
-from fp16 import FP16_Module
-from fp16 import FP16_Optimizer
-from learning_rates import AnnealingLR
-from model import BertModel
-from model import get_params_for_weight_decay_optimization
-from model import gpt2_get_params_for_weight_decay_optimization
-from model import DistributedDataParallel as LocalDDP
-import mpu
+from megatron.fp16 import FP16_Module
+from megatron.fp16 import FP16_Optimizer
+from megatron.learning_rates import AnnealingLR
+from megatron.model import BertModel
+from megatron.model import get_params_for_weight_decay_optimization
+from megatron.model import gpt2_get_params_for_weight_decay_optimization
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron import mpu
 from apex.optimizers import FusedAdam as Adam
-from utils import Timers
-from utils import save_checkpoint
-from utils import load_checkpoint
-from utils import report_memory
-from utils import print_args
-from utils import print_params_min_max_norm
-from utils import print_rank_0
-from utils import enable_adlr_autoresume
-from utils import check_adlr_autoresume_termination
+from megatron.utils import Timers
+from megatron.utils import save_checkpoint
+from megatron.utils import load_checkpoint
+from megatron.utils import report_memory
+from megatron.utils import print_args
+from megatron.utils import print_params_min_max_norm
+from megatron.utils import print_rank_0
+from megatron.utils import enable_adlr_autoresume
+from megatron.utils import check_adlr_autoresume_termination
 def get_model(args):
     """Build the model."""
@@ -24,23 +24,23 @@ import torch
 from arguments import get_args
 from configure_data import configure_data
-from fp16 import FP16_Module
-from fp16 import FP16_Optimizer
-from learning_rates import AnnealingLR
-from model import GPT2Model
-from model import gpt2_get_params_for_weight_decay_optimization
-from model import DistributedDataParallel as LocalDDP
-import mpu
+from megatron.fp16 import FP16_Module
+from megatron.fp16 import FP16_Optimizer
+from megatron.learning_rates import AnnealingLR
+from megatron.model import GPT2Model
+from megatron.model import gpt2_get_params_for_weight_decay_optimization
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron import mpu
 from apex.optimizers import FusedAdam as Adam
-from utils import Timers
-from utils import save_checkpoint
-from utils import load_checkpoint
-from utils import report_memory
-from utils import print_args
-from utils import print_params_min_max_norm
-from utils import print_rank_0
-from utils import enable_adlr_autoresume
-from utils import check_adlr_autoresume_termination
+from megatron.utils import Timers
+from megatron.utils import save_checkpoint
+from megatron.utils import load_checkpoint
+from megatron.utils import report_memory
+from megatron.utils import print_args
+from megatron.utils import print_params_min_max_norm
+from megatron.utils import print_rank_0
+from megatron.utils import enable_adlr_autoresume
+from megatron.utils import check_adlr_autoresume_termination
 from gpt2_data_loader import make_gpt2_dataloaders
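Note that the tokenizer hunk above keeps its sys.path.append('..') line, so the new megatron.* imports resolve only when the directory containing megatron/ is importable. A minimal sketch of that resolution pattern, assuming the package is not pip-installed and the calling script sits one level below the repository root (both assumptions, not stated in this commit):

import os
import sys

# Mirror the sys.path.append('..') pattern from the tokenizer hunk:
# put the repository root (the directory that contains megatron/) on sys.path.
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, repo_root)

from megatron import mpu  # resolves against <repo_root>/megatron/mpu/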