Commit 72aeb0f3 authored by dongcl's avatar dongcl
Browse files

bug fix

parent d46a984e
...@@ -5,8 +5,6 @@ import types ...@@ -5,8 +5,6 @@ import types
import argparse import argparse
import torch import torch
from megatron.training import get_args
class MegatronAdaptation: class MegatronAdaptation:
""" """
...@@ -191,8 +189,7 @@ class CoreAdaptation(MegatronAdaptationABC): ...@@ -191,8 +189,7 @@ class CoreAdaptation(MegatronAdaptationABC):
apply_wrapper=True) apply_wrapper=True)
# flux # flux
args = get_args() if os.getenv("USE_FLUX_OVERLAP", 0):
if args.use_flux:
import flux import flux
from ..core.tensor_parallel import ( from ..core.tensor_parallel import (
......
...@@ -284,7 +284,7 @@ class AGLinear(torch.autograd.Function): ...@@ -284,7 +284,7 @@ class AGLinear(torch.autograd.Function):
) )
torch.cuda.current_stream().synchronize() torch.cuda.current_stream().synchronize()
grad_input = grad_input.view(sequence_len // get_tensor_model_parallel_world_size(), batch_size, -1) grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
else: else:
grad_input = grad_output.matmul(weight) grad_input = grad_output.matmul(weight)
......
...@@ -30,7 +30,7 @@ except ImportError: ...@@ -30,7 +30,7 @@ except ImportError:
LNImpl = WrappedTorchNorm LNImpl = WrappedTorchNorm
def get_mtp_spec(transformer_layer, use_te=False): def get_mtp_spec(transformer_layer, use_te=False, use_flux=False):
""" """
Multi Token Predication Layer Specification. Multi Token Predication Layer Specification.
""" """
...@@ -39,11 +39,11 @@ def get_mtp_spec(transformer_layer, use_te=False): ...@@ -39,11 +39,11 @@ def get_mtp_spec(transformer_layer, use_te=False):
module=MultiTokenPredictor, module=MultiTokenPredictor,
submodules=MultiTokenPredicationSubmodules( submodules=MultiTokenPredicationSubmodules(
embedding=None, embedding=None,
enorm=TENorm if use_te else LNImpl, enorm=TENorm if use_te or use_flux else LNImpl,
hnorm=TENorm if use_te else LNImpl, hnorm=TENorm if use_te or use_flux else LNImpl,
eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear, eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
transformer_layer=transformer_layer, transformer_layer=transformer_layer,
final_layernorm=TENorm if use_te else LNImpl, final_layernorm=TENorm if use_te or use_flux else LNImpl,
output_layer=None, output_layer=None,
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment