Commit de593298 authored by Lawrence McAfee
Browse files

allreduce layernorm fixes.

parent cf1c7848
......@@ -21,6 +21,7 @@ from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_timers
from megatron import mpu
......@@ -273,9 +274,8 @@ class MegatronOptimizer(ABC):
# when sequence parallelism is used
if mpu.get_tensor_model_parallel_world_size() > 1 and \
args.sequence_parallel:
raise Exception("hi.")
grads = []
for model_module in model:
for model_module in self.models:
unwrapped_model = unwrap_model(
model_module, (torchDDP, LocalDDP, Float16Module))
for param in unwrapped_model.parameters():
......
......@@ -23,7 +23,6 @@ import time
_TRAIN_START_TIME = time.time()
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_args
from megatron import get_signal_handler
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment