Commit 2862a32a authored by dongcl's avatar dongcl
Browse files

fix flux import error

parent 23eb9b17
......@@ -165,7 +165,7 @@ class CoreAdaptation(MegatronAdaptationABC):
# flux
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
from ..core.tensor_parallel import (
from ..core.tensor_parallel.layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear
)
......
......@@ -12,8 +12,6 @@ from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import WrappedTensor, deprecate_inference_params
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
def gpt_model_init_wrapper(fn):
@wraps(fn)
......@@ -25,6 +23,8 @@ def gpt_model_init_wrapper(fn):
(self.post_process or self.mtp_process)
and int(os.getenv("USE_FLUX_OVERLAP", "0"))
):
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear
self.output_layer = FluxColumnParallelLinear(
self.config.hidden_size,
self.vocab_size,
......
from .layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear,
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment