Commit b886b7bb authored by Mohammad Shoeybi

created megatron package

parent 4947002d
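This commit moves the previously top-level modules under a single megatron package, so callers import them as megatron.<module>; the training scripts' own helpers (arguments, configure_data, gpt2_data_loader) stay top-level. A rough sketch of the implied layout, inferred only from the imports rewritten below (the full file list is not part of this diff, and each entry could be a single module rather than a subpackage):

# Assumed layout after this commit (inferred from the rewritten imports)
megatron/
    __init__.py
    fp16/               # FP16_Module, FP16_Optimizer
    mpu/                # model-parallel utilities
    model/              # BertModel, GPT2Model, local DistributedDataParallel
    learning_rates.py   # AnnealingLR
    utils.py            # Timers, checkpointing and logging helpers
    data_utils/         # tokenization_gpt2 (GPT2Tokenizer)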
@@ -22,9 +22,9 @@ import numpy as np
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from fp16 import FP16_Optimizer
-import mpu
-import model
+from megatron.fp16 import FP16_Optimizer
+from megatron import mpu
+from megatron import model
 def print_rank_0(message):
...
@@ -16,7 +16,7 @@
 import sys
 sys.path.append('..')
-from data_utils.tokenization_gpt2 import GPT2Tokenizer
+from megatron.data_utils.tokenization_gpt2 import GPT2Tokenizer
 class Tokenizer:
...
@@ -24,24 +24,24 @@ import torch.nn.functional as F
 from arguments import get_args
 from configure_data import configure_data
-from fp16 import FP16_Module
-from fp16 import FP16_Optimizer
-from learning_rates import AnnealingLR
-from model import BertModel
-from model import get_params_for_weight_decay_optimization
-from model import gpt2_get_params_for_weight_decay_optimization
-from model import DistributedDataParallel as LocalDDP
-import mpu
+from megatron.fp16 import FP16_Module
+from megatron.fp16 import FP16_Optimizer
+from megatron.learning_rates import AnnealingLR
+from megatron.model import BertModel
+from megatron.model import get_params_for_weight_decay_optimization
+from megatron.model import gpt2_get_params_for_weight_decay_optimization
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron import mpu
 from apex.optimizers import FusedAdam as Adam
-from utils import Timers
-from utils import save_checkpoint
-from utils import load_checkpoint
-from utils import report_memory
-from utils import print_args
-from utils import print_params_min_max_norm
-from utils import print_rank_0
-from utils import enable_adlr_autoresume
-from utils import check_adlr_autoresume_termination
+from megatron.utils import Timers
+from megatron.utils import save_checkpoint
+from megatron.utils import load_checkpoint
+from megatron.utils import report_memory
+from megatron.utils import print_args
+from megatron.utils import print_params_min_max_norm
+from megatron.utils import print_rank_0
+from megatron.utils import enable_adlr_autoresume
+from megatron.utils import check_adlr_autoresume_termination
 def get_model(args):
     """Build the model."""
...
@@ -24,23 +24,23 @@ import torch
 from arguments import get_args
 from configure_data import configure_data
-from fp16 import FP16_Module
-from fp16 import FP16_Optimizer
-from learning_rates import AnnealingLR
-from model import GPT2Model
-from model import gpt2_get_params_for_weight_decay_optimization
-from model import DistributedDataParallel as LocalDDP
-import mpu
+from megatron.fp16 import FP16_Module
+from megatron.fp16 import FP16_Optimizer
+from megatron.learning_rates import AnnealingLR
+from megatron.model import GPT2Model
+from megatron.model import gpt2_get_params_for_weight_decay_optimization
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron import mpu
 from apex.optimizers import FusedAdam as Adam
-from utils import Timers
-from utils import save_checkpoint
-from utils import load_checkpoint
-from utils import report_memory
-from utils import print_args
-from utils import print_params_min_max_norm
-from utils import print_rank_0
-from utils import enable_adlr_autoresume
-from utils import check_adlr_autoresume_termination
+from megatron.utils import Timers
+from megatron.utils import save_checkpoint
+from megatron.utils import load_checkpoint
+from megatron.utils import report_memory
+from megatron.utils import print_args
+from megatron.utils import print_params_min_max_norm
+from megatron.utils import print_rank_0
+from megatron.utils import enable_adlr_autoresume
+from megatron.utils import check_adlr_autoresume_termination
 from gpt2_data_loader import make_gpt2_dataloaders
...
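As a quick sanity check of the relocation (assuming the repository root containing the megatron/ directory is on PYTHONPATH and the existing dependencies such as torch and apex are installed; this commit adds no packaging metadata), the rewritten imports should resolve like this:

# Hypothetical smoke test; every name below comes from an import rewritten in this commit
from megatron import mpu, model
from megatron.fp16 import FP16_Module, FP16_Optimizer
from megatron.learning_rates import AnnealingLR
from megatron.utils import Timers
from megatron.data_utils.tokenization_gpt2 import GPT2Tokenizer
print('megatron package imports resolved')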