Commit 2862a32a authored by dongcl's avatar dongcl
Browse files

fix flux import error

parent 23eb9b17
...@@ -165,7 +165,7 @@ class CoreAdaptation(MegatronAdaptationABC): ...@@ -165,7 +165,7 @@ class CoreAdaptation(MegatronAdaptationABC):
# flux # flux
if int(os.getenv("USE_FLUX_OVERLAP", "0")): if int(os.getenv("USE_FLUX_OVERLAP", "0")):
from ..core.tensor_parallel import ( from ..core.tensor_parallel.layers import (
FluxColumnParallelLinear, FluxColumnParallelLinear,
FluxRowParallelLinear FluxRowParallelLinear
) )
......
...@@ -12,8 +12,6 @@ from megatron.core.inference.contexts import BaseInferenceContext ...@@ -12,8 +12,6 @@ from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import WrappedTensor, deprecate_inference_params from megatron.core.utils import WrappedTensor, deprecate_inference_params
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
def gpt_model_init_wrapper(fn): def gpt_model_init_wrapper(fn):
@wraps(fn) @wraps(fn)
...@@ -25,6 +23,8 @@ def gpt_model_init_wrapper(fn): ...@@ -25,6 +23,8 @@ def gpt_model_init_wrapper(fn):
(self.post_process or self.mtp_process) (self.post_process or self.mtp_process)
and int(os.getenv("USE_FLUX_OVERLAP", "0")) and int(os.getenv("USE_FLUX_OVERLAP", "0"))
): ):
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear
self.output_layer = FluxColumnParallelLinear( self.output_layer = FluxColumnParallelLinear(
self.config.hidden_size, self.config.hidden_size,
self.vocab_size, self.vocab_size,
......
from .layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear,
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment