evt_fugx1 / dcu_megatron · Commits
Commit b9a97686, authored Apr 10, 2025 by dongcl

support flux

Parent: 9eb8683b
Showing 3 changed files with 15 additions and 7 deletions (+15, -7)
dcu_megatron/core/tensor_parallel/__init__.py   +1   -1
dcu_megatron/core/tensor_parallel/layers.py     +13  -6
dcu_megatron/training/arguments.py              +1   -0
dcu_megatron/core/tensor_parallel/__init__.py

 from .layers import (
     parallel_linear_init_wrapper,
     ColumnParallelLinearPatch,
     RowParallelLinearPatch,
     vocab_parallel_embedding_forward,
 ...
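The exported names suggest that dcu_megatron patches the upstream Megatron tensor-parallel layers rather than replacing them. A minimal usage sketch, assuming (not shown in this diff) that parallel_linear_init_wrapper wraps an __init__ and that vocab_parallel_embedding_forward is a drop-in forward method:

# Hypothetical sketch -- how exports like these are typically applied to the
# upstream Megatron classes; dcu_megatron's actual patch mechanism is not
# visible in this commit.
import megatron.core.tensor_parallel.layers as mcore_layers
from dcu_megatron.core.tensor_parallel import (
    parallel_linear_init_wrapper,
    vocab_parallel_embedding_forward,
)

# Wrap the upstream constructor and swap in the patched embedding forward.
mcore_layers.ColumnParallelLinear.__init__ = parallel_linear_init_wrapper(
    mcore_layers.ColumnParallelLinear.__init__
)
mcore_layers.VocabParallelEmbedding.forward = vocab_parallel_embedding_forward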
dcu_megatron/core/tensor_parallel/layers.py

-from typing import Callable
+import os
+import warnings
+from functools import wraps
+from typing import Callable, List, Optional
 import flux
 import torch
@@ -20,11 +23,18 @@ from megatron.core.tensor_parallel.layers import (
     VocabParallelEmbedding,
 )
 from megatron.core.tensor_parallel.mappings import (
+    copy_to_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
     reduce_scatter_to_sequence_parallel_region,
 )
 from megatron.core.tensor_parallel.utils import VocabUtility
 from megatron.core.tensor_parallel.mappings import _reduce
+from megatron.core.tensor_parallel.layers import (
+    custom_fwd,
+    custom_bwd,
+    linear_with_frozen_weight,
+    linear_with_grad_accumulation_and_async_allreduce
+)

 _grad_accum_fusion_available = True
 try:
@@ -32,8 +42,6 @@ try:
 except ImportError:
     _grad_accum_fusion_available = False
-from flux.cpp_mod import ReduceScatterOption

 def vocab_parallel_embedding_init(
     self,
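The surrounding context already guards the optional fused grad-accumulation kernel with a try/except ImportError flag, while flux itself is imported unconditionally at the top of layers.py. A sketch of that same guard pattern applied to flux, not part of this commit and with a made-up _flux_available flag:

# Sketch only -- not in this commit. layers.py imports flux unconditionally;
# this shows the optional-import guard used above for the fused
# grad-accumulation kernel, applied to flux instead.
try:
    import flux  # fused communication + GEMM kernels
    _flux_available = True
except ImportError:
    flux = None
    _flux_available = False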
@@ -351,7 +359,7 @@ class AGLinear(torch.autograd.Function):
         if ctx.allreduce_dgrad:
             handle.wait()

-        return grad_input, grad_weight, grad_bias, None, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


 def ag_linear(
@@ -652,7 +660,7 @@ class LinearRS(torch.autograd.Function):
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        return grad_input, grad_weight, grad_bias, None, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


 def linear_rs(
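Both return-statement changes above follow from the contract of torch.autograd.Function: backward must return exactly one value per positional argument of forward, with None for non-differentiable inputs. The extra trailing None therefore implies that AGLinear.forward and LinearRS.forward each gained one additional argument in this commit (the new signatures are outside the shown hunks). A self-contained illustration of the rule, not code from this repo:

import torch

class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, some_flag):  # 3 forward inputs
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward input: grad for x, None for the rest.
        return grad_out * ctx.scale, None, None

x = torch.ones(2, requires_grad=True)
ScaleBy.apply(x, 3.0, True).sum().backward()
print(x.grad)  # tensor([3., 3.])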
@@ -863,7 +871,6 @@ class ColumnParallelLinearPatch(torch.nn.Module):
         else:
             self._forward_impl = linear_with_grad_accumulation_and_async_allreduce

         allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad

         forward_params = {
 ...
dcu_megatron/training/arguments.py

@@ -80,6 +80,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     parser = _add_ft_package_args(parser)
     parser = _add_config_logger_args(parser)
     parser = _add_rerun_machine_args(parser)
+    parser = _add_flux_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:
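Only the call site of _add_flux_args is added here; its body is not part of this diff. For orientation, Megatron-style _add_*_args helpers take the argparse parser, add a named argument group, and return the parser. A hypothetical sketch, where the group title and the --use-flux flag are assumptions rather than the repo's actual options:

# Hypothetical sketch of _add_flux_args -- its real contents are not shown
# in this commit.
def _add_flux_args(parser):
    group = parser.add_argument_group(title='flux')
    group.add_argument('--use-flux', action='store_true',
                       help='Enable flux fused communication/GEMM kernels.')
    return parser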