Commit b9a97686 authored by dongcl

support flux

parent 9eb8683b
from .layers import (
-    parallel_linear_init_wrapper
+    parallel_linear_init_wrapper,
    ColumnParallelLinearPatch,
    RowParallelLinearPatch,
    vocab_parallel_embedding_forward,
......
-from typing import Callable
+import os
+import warnings
+from functools import wraps
+from typing import Callable, List, Optional
import flux
import torch
@@ -20,11 +23,18 @@ from megatron.core.tensor_parallel.layers import (
    VocabParallelEmbedding,
)
from megatron.core.tensor_parallel.mappings import (
+    copy_to_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.mappings import _reduce
+from megatron.core.tensor_parallel.layers import (
+    custom_fwd,
+    custom_bwd,
+    linear_with_frozen_weight,
+    linear_with_grad_accumulation_and_async_allreduce
+)
_grad_accum_fusion_available = True
try:
@@ -32,8 +42,6 @@ try:
except ImportError:
    _grad_accum_fusion_available = False
+from flux.cpp_mod import ReduceScatterOption
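Note (not part of this commit): flux supplies fused GEMM-plus-communication kernels, which is why ReduceScatterOption is imported alongside the Megatron tensor-parallel helpers above. For reference, the unfused baseline that a flux-backed reduce-scatter linear is meant to overlap looks roughly like the sketch below, written in plain PyTorch against assumed sequence-first activations; the helper name and shapes are illustrative and do not appear in this diff.

import torch
import torch.distributed as dist

def linear_then_reduce_scatter(x, weight, tp_group):
    # Each tensor-parallel rank computes a partial GEMM result locally.
    partial = torch.matmul(x, weight.t())
    # A separate collective then sums the partials and scatters the
    # sequence dimension across the group (no overlap with the GEMM).
    out_shape = list(partial.shape)
    out_shape[0] //= dist.get_world_size(tp_group)
    out = torch.empty(out_shape, dtype=partial.dtype, device=partial.device)
    dist.reduce_scatter_tensor(out, partial.contiguous(), group=tp_group)
    return out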
def vocab_parallel_embedding_init(
    self,
@@ -351,7 +359,7 @@ class AGLinear(torch.autograd.Function):
        if ctx.allreduce_dgrad:
            handle.wait()
-        return grad_input, grad_weight, grad_bias, None, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
def ag_linear(
@@ -652,7 +660,7 @@ class LinearRS(torch.autograd.Function):
        grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None
-        return grad_input, grad_weight, grad_bias, None, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
def linear_rs(
@@ -863,7 +871,6 @@ class ColumnParallelLinearPatch(torch.nn.Module):
        else:
            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
-        allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad
        forward_params = {
......
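Aside on the two return-value changes above (not part of the patch itself): torch.autograd.Function.backward must return exactly one gradient, or None, per argument of the matching forward, so adding an extra forward argument to AGLinear / LinearRS forces the extra trailing None in their backward. A minimal, generic illustration of that rule:

import torch

class ScaledMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, scale):
        # Three forward inputs; 'scale' is a plain Python float.
        ctx.save_for_backward(x, weight)
        ctx.scale = scale
        return (x @ weight.t()) * scale

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output @ weight * ctx.scale
        grad_w = grad_output.t() @ x * ctx.scale
        # One entry per forward input; None for the non-tensor 'scale'.
        return grad_x, grad_w, None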
@@ -80,6 +80,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    parser = _add_ft_package_args(parser)
    parser = _add_config_logger_args(parser)
    parser = _add_rerun_machine_args(parser)
+    parser = _add_flux_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
......
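The body of _add_flux_args is in a part of the diff not shown on this page. Megatron-LM's _add_*_args helpers conventionally attach an argument group and return the parser; a purely hypothetical sketch of that pattern follows, with an invented flag name used only for illustration (the real options added by this commit are not visible here).

def _add_flux_args(parser):
    # Hypothetical example of the usual Megatron argument-group pattern;
    # the actual flux options live in the portion of the diff not shown.
    group = parser.add_argument_group(title='flux')
    group.add_argument('--use-flux', action='store_true',
                       help='Hypothetical flag: enable flux fused '
                            'GEMM-communication kernels.')
    return parser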