Commit 72aeb0f3 authored by dongcl

Bug fix: gate the flux overlap path on the USE_FLUX_OVERLAP environment variable instead of args.use_flux, reuse the local world_size when reshaping grad_input in AGLinear's backward pass, and add a use_flux flag to get_mtp_spec so the MTP norm layers use TENorm when flux is enabled.

parent d46a984e
@@ -5,8 +5,6 @@ import types
 import argparse
 import torch
-from megatron.training import get_args
-
 
 
 class MegatronAdaptation:
     """
@@ -191,8 +189,7 @@ class CoreAdaptation(MegatronAdaptationABC):
                                  apply_wrapper=True)
         # flux
-        args = get_args()
-        if args.use_flux:
+        if os.getenv("USE_FLUX_OVERLAP", 0):
             import flux
             from ..core.tensor_parallel import (
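A quick sketch of how the new gate evaluates, since it replaces the boolean `args.use_flux`. `flux_overlap_enabled` below is a hypothetical helper for illustration, not code from this patch; the point is that `os.getenv` returns the raw string whenever the variable is set, so truthiness rather than the numeric value decides the branch:

```python
import os

# Hypothetical helper mirroring the new gate: os.getenv returns the raw string
# when USE_FLUX_OVERLAP is set and the integer default 0 when it is not, so any
# non-empty value -- including "0" -- enables the branch.
def flux_overlap_enabled():
    return bool(os.getenv("USE_FLUX_OVERLAP", 0))

os.environ.pop("USE_FLUX_OVERLAP", None)
print(flux_overlap_enabled())            # False: unset falls back to 0

os.environ["USE_FLUX_OVERLAP"] = "1"
print(flux_overlap_enabled())            # True

os.environ["USE_FLUX_OVERLAP"] = "0"
print(flux_overlap_enabled())            # True as well: "0" is a non-empty string
```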
@@ -284,7 +284,7 @@ class AGLinear(torch.autograd.Function):
             )
             torch.cuda.current_stream().synchronize()
-            grad_input = grad_input.view(sequence_len // get_tensor_model_parallel_world_size(), batch_size, -1)
+            grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
         else:
             grad_input = grad_output.matmul(weight)
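For the `AGLinear` backward change, the reshape itself is unchanged; the fix only reuses the `world_size` value already in scope instead of calling `get_tensor_model_parallel_world_size()` again. A toy shape check under assumed sizes (the flat starting layout here is illustrative, not the repository's actual tensor):

```python
import torch

# Assumed example sizes, not taken from the repository.
world_size = 4            # tensor-model-parallel group size cached earlier in backward
sequence_len = 32         # full (all-gathered) sequence length
batch_size = 2
hidden = 16

# One sequence shard's gradient, flattened; the view restores the local
# [sequence_len // world_size, batch_size, hidden] layout.
grad_input = torch.randn(sequence_len // world_size * batch_size * hidden)
grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
print(grad_input.shape)   # torch.Size([8, 2, 16])
```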
@@ -30,7 +30,7 @@ except ImportError:
     LNImpl = WrappedTorchNorm
 
 
-def get_mtp_spec(transformer_layer, use_te=False):
+def get_mtp_spec(transformer_layer, use_te=False, use_flux=False):
     """
     Multi Token Predication Layer Specification.
     """
@@ -39,11 +39,11 @@ def get_mtp_spec(transformer_layer, use_te=False):
         module=MultiTokenPredictor,
         submodules=MultiTokenPredicationSubmodules(
             embedding=None,
-            enorm=TENorm if use_te else LNImpl,
-            hnorm=TENorm if use_te else LNImpl,
+            enorm=TENorm if use_te or use_flux else LNImpl,
+            hnorm=TENorm if use_te or use_flux else LNImpl,
             eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
             transformer_layer=transformer_layer,
-            final_layernorm=TENorm if use_te else LNImpl,
+            final_layernorm=TENorm if use_te or use_flux else LNImpl,
             output_layer=None,
         )
     )
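To make the effect of the new `use_flux` flag concrete: before this change a flux-only configuration (`use_te=False`) fell back to `LNImpl` for `enorm`, `hnorm`, and `final_layernorm`; now either flag selects `TENorm`, while `eh_proj` still keys on `use_te` alone. A minimal stand-alone sketch of the ternary (`pick_norm` is hypothetical, and `TENorm`/`LNImpl` are string stand-ins, not the real norm classes):

```python
# Stand-ins for the real Transformer Engine / Torch norm implementations.
TENorm, LNImpl = "TENorm", "LNImpl"

def pick_norm(use_te=False, use_flux=False):
    # Mirrors `TENorm if use_te or use_flux else LNImpl` from the MTP spec above.
    return TENorm if use_te or use_flux else LNImpl

print(pick_norm())                  # LNImpl
print(pick_norm(use_te=True))       # TENorm
print(pick_norm(use_flux=True))     # TENorm: the case this commit enables
```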