Commit 929c780c authored by mohammad's avatar mohammad
Browse files

Merge branch 'main' into log_grad_norm

parents 286c4d97 577ad7d3
...@@ -30,7 +30,6 @@ def import_layernorm(fp32_residual_connection): ...@@ -30,7 +30,6 @@ def import_layernorm(fp32_residual_connection):
from .distributed import * from .distributed import *
from .vit_model import VitModel
from .bert_model import (BertModel, from .bert_model import (BertModel,
BertModelFirstStage, BertModelFirstStage,
BertModelIntermediateStage, BertModelIntermediateStage,
......
...@@ -19,7 +19,7 @@ import torch ...@@ -19,7 +19,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args, get_timers, mpu, print_rank_0 from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import VitModel from megatron.model.vit_model import VitModel
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
from megatron import get_args from megatron import get_args
from megatron import print_rank_0 from megatron import print_rank_0
from megatron.model import VitModel from megatron.model.vit_model import VitModel
from megatron.data.vit_dataset import build_train_valid_datasets from megatron.data.vit_dataset import build_train_valid_datasets
from tasks.vision.eval_utils import accuracy_func_provider from tasks.vision.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune from tasks.vision.finetune_utils import finetune
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment