evt_fugx1 / dcu_megatron / Commits / ec7c8bc3

Commit ec7c8bc3, authored Apr 17, 2025 by dongcl
Parent: 138b70a2

    replace te with flux when using flux

Showing 4 changed files with 296 additions and 224 deletions.
Changed files:

  dcu_megatron/adaptor/megatron_adaptor.py           +7    -17
  dcu_megatron/core/models/gpt/gpt_layer_specs.py    +13   -11
  dcu_megatron/core/tensor_parallel/__init__.py      +2    -4
  dcu_megatron/core/tensor_parallel/layers.py        +274  -192
dcu_megatron/adaptor/megatron_adaptor.py

@@ -190,27 +190,17 @@ class CoreAdaptation(MegatronAdaptationABC):
         # flux
         if os.getenv("USE_FLUX_OVERLAP", 0):
-            import flux
             from ..core.tensor_parallel import (
-                ColumnParallelLinearPatch,
-                RowParallelLinearPatch,
-                column_parallel_linear_init_wrapper,
-                row_parallel_linear_init_wrapper
+                FluxColumnParallelLinear,
+                FluxRowParallelLinear
             )
             from ..core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec

             MegatronAdaptation.register(
-                "megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
-                column_parallel_linear_init_wrapper,
-                apply_wrapper=True
-            )
-            MegatronAdaptation.register(
-                "megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
-                ColumnParallelLinearPatch.forward
+                "megatron.core.extensions.transformer_engine.TEColumnParallelLinear",
+                FluxColumnParallelLinear
             )
             MegatronAdaptation.register(
-                "megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
-                row_parallel_linear_init_wrapper,
-                apply_wrapper=True
-            )
-            MegatronAdaptation.register(
-                "megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
-                RowParallelLinearPatch.forward
+                "megatron.core.extensions.transformer_engine.TERowParallelLinear",
+                FluxRowParallelLinear
             )
             MegatronAdaptation.register(
-                "megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_local_spec",
+                "megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                 get_gpt_layer_with_flux_spec
             )

     def patch_training(self):
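Note on the gate above (editorial sketch; the environment variable and the registration calls come from the hunk itself, everything else is illustrative): the flux registrations are applied only when USE_FLUX_OVERLAP is set before the adaptor runs, and since os.getenv() returns a string whenever the variable is set, any non-empty value, including "0", enables the branch.

    # Minimal sketch of the gate, assuming the adaptor module is imported afterwards.
    import os

    os.environ["USE_FLUX_OVERLAP"] = "1"            # must be set before the adaptor applies its patches
    enabled = bool(os.getenv("USE_FLUX_OVERLAP", 0))
    print(enabled)  # True -> the TE*ParallelLinear entries are replaced by the Flux classes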
dcu_megatron/core/models/gpt/gpt_layer_specs.py

@@ -3,7 +3,6 @@ from typing import Optional
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
-from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.identity_op import IdentityOp

@@ -17,6 +16,9 @@ from megatron.core.transformer.transformer_layer import (
     TransformerLayer,
     TransformerLayerSubmodules,
 )
+from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
 from megatron.core.utils import is_te_min_version

 try:

@@ -79,13 +81,13 @@ def get_gpt_layer_with_flux_spec(
             module=MLASelfAttention,
             params={"attn_mask_type": AttnMaskType.causal},
             submodules=MLASelfAttentionSubmodules(
-                linear_q_proj=ColumnParallelLinear,
-                linear_q_down_proj=ColumnParallelLinear,
-                linear_q_up_proj=ColumnParallelLinear,
-                linear_kv_down_proj=ColumnParallelLinear,
-                linear_kv_up_proj=ColumnParallelLinear,
+                linear_q_proj=FluxColumnParallelLinear,
+                linear_q_down_proj=FluxColumnParallelLinear,
+                linear_q_up_proj=FluxColumnParallelLinear,
+                linear_kv_down_proj=FluxColumnParallelLinear,
+                linear_kv_up_proj=FluxColumnParallelLinear,
                 core_attention=TEDotProductAttention,
-                linear_proj=RowParallelLinear,
+                linear_proj=FluxRowParallelLinear,
                 q_layernorm=TENorm if qk_layernorm else IdentityOp,
                 kv_layernorm=TENorm if qk_layernorm else IdentityOp,
             ),

@@ -111,9 +113,9 @@ def get_gpt_layer_with_flux_spec(
             module=SelfAttention,
             params={"attn_mask_type": AttnMaskType.causal},
             submodules=SelfAttentionSubmodules(
-                linear_qkv=ColumnParallelLinear,
+                linear_qkv=FluxColumnParallelLinear,
                 core_attention=TEDotProductAttention,
-                linear_proj=RowParallelLinear,
+                linear_proj=FluxRowParallelLinear,
                 q_layernorm=qk_norm if qk_layernorm else IdentityOp,
                 k_layernorm=qk_norm if qk_layernorm else IdentityOp,
             ),

@@ -145,8 +147,8 @@ def get_mlp_module_flux_spec(
     return ModuleSpec(
         module=MLP,
         submodules=MLPSubmodules(
-            linear_fc1=ColumnParallelLinear,
-            linear_fc2=RowParallelLinear,
+            linear_fc1=FluxColumnParallelLinear,
+            linear_fc2=FluxRowParallelLinear,
         ),
     )
 else:
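For context, a standalone sketch of the substitution pattern above (hypothetical example, not the full get_gpt_layer_with_flux_spec; the megatron-core import paths are the usual ones and may differ by version): the flux layer spec is the ordinary spec with the linear submodules pointed at the Flux classes.

    from megatron.core.transformer.spec_utils import ModuleSpec
    from megatron.core.transformer.mlp import MLP, MLPSubmodules
    from dcu_megatron.core.tensor_parallel.layers import (
        FluxColumnParallelLinear,
        FluxRowParallelLinear,
    )

    # MLP spec with the flux linears swapped in for the TE / vanilla parallel linears.
    mlp_spec = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=FluxColumnParallelLinear,
            linear_fc2=FluxRowParallelLinear,
        ),
    )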
dcu_megatron/core/tensor_parallel/__init__.py

 from .layers import (
-    column_parallel_linear_init_wrapper,
-    row_parallel_linear_init_wrapper,
-    ColumnParallelLinearPatch,
-    RowParallelLinearPatch,
+    FluxColumnParallelLinear,
+    FluxRowParallelLinear,
     vocab_parallel_embedding_forward,
     vocab_parallel_embedding_init,
 )
\ No newline at end of file
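After this change the package root exports the Flux classes (alongside the embedding helpers), so callers import them as:

    from dcu_megatron.core.tensor_parallel import (
        FluxColumnParallelLinear,
        FluxRowParallelLinear,
    )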
dcu_megatron/core/tensor_parallel/layers.py

@@ -5,9 +5,8 @@ from typing import Callable, List, Optional
 try:
     import flux
-    HAS_FLUX = True
 except ImportError:
-    HAS_FLUX = False
+    raise ImportError("flux is NOT installed")

 import torch
 import torch.nn.functional as F
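Behavioral note with a small sketch (assuming nothing beyond what the hunk shows): a missing flux package now surfaces as an ImportError when the module is imported, instead of setting a HAS_FLUX flag that was only checked later in forward().

    # Sketch: importing the patched layers module without flux installed now fails eagerly.
    try:
        import dcu_megatron.core.tensor_parallel.layers  # noqa: F401
    except ImportError as exc:
        print(f"flux unavailable at import time: {exc}")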
@@ -39,6 +38,10 @@ from megatron.core.tensor_parallel.mappings import (
 )
 from megatron.core.tensor_parallel.utils import VocabUtility
 from megatron.core.tensor_parallel.mappings import _reduce
+from megatron.core.tensor_parallel import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+)
 from megatron.core.tensor_parallel.layers import (
     custom_fwd,
     custom_bwd,

@@ -218,6 +221,8 @@ class AGLinear(torch.autograd.Function):
             output = output.view(sequence_len * world_size, batch_size, -1)
         else:
             output = torch.matmul(input, weight.t())
+            if bias is not None:
+                output = output + bias
         return output

@@ -232,7 +237,7 @@ class AGLinear(torch.autograd.Function):
         transpose_weight = ctx.transpose_weight
         bw_gemm_rs_op = ctx.bw_gemm_rs_op
-        wgrad_compute = True
+        wgrad_compute = weight.requires_grad

         if grad_output_buffer is not None:
             if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
                 grad_output_buffer.append(grad_output)
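The wgrad_compute change (mirrored in the LinearRS backward below) ties the weight-gradient computation to weight.requires_grad, so frozen weights skip the weight-gradient GEMM. A trivial, self-contained illustration of the flag, with a made-up tensor:

    import torch

    weight = torch.randn(8, 4, requires_grad=False)  # e.g. a frozen weight
    wgrad_compute = weight.requires_grad
    print(wgrad_compute)  # False -> the weight-gradient path is skipped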
@@ -300,10 +305,14 @@ class AGLinear(torch.autograd.Function):
             )

         if not ctx.sequence_parallel and ctx.allreduce_dgrad:
-            # Asynchronous all-reduce
-            handle = torch.distributed.all_reduce(
-                grad_input, group=get_tensor_model_parallel_group(), async_op=True
-            )
+            if weight.requires_grad:
+                # Asynchronous all-reduce
+                handle = torch.distributed.all_reduce(
+                    grad_input, group=get_tensor_model_parallel_group(), async_op=True
+                )
+            else:
+                grad_input = _reduce(grad_input)
+                return grad_input, None, None, None, None, None, None, None, None, None, None

         if ctx.gradient_accumulation_fusion:
             if wgrad_compute:

@@ -530,7 +539,6 @@ class LinearRS(torch.autograd.Function):
             # output = _reduce_scatter_along_first_dim(output)
         else:
             output = torch.matmul(input, weight.t())
-            output = _reduce(output)
         return output

@@ -545,7 +553,7 @@ class LinearRS(torch.autograd.Function):
         transpose_weight = ctx.transpose_weight
         bw_ag_gemm_op = ctx.bw_ag_gemm_op
-        wgrad_compute = True
+        wgrad_compute = weight.requires_grad

         if grad_output_buffer is not None:
             if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
                 grad_output_buffer.append(grad_output)

@@ -604,6 +612,9 @@ class LinearRS(torch.autograd.Function):
         else:
             grad_input = grad_output.matmul(weight)

+        if not weight.requires_grad:
+            return grad_input, None, None, None, None, None, None, None, None, None, None
+
         if ctx.sequence_parallel and wgrad_compute:
             handle.wait()
@@ -772,39 +783,99 @@ def linear_rs(
 linear_rs.warned = False


-def column_parallel_linear_init_wrapper(fn):
-    @wraps(fn)
-    def wrapper(self, *args, **kwargs):
-        fn(self, *args, **kwargs)
-        # flux params
-        self.use_flux = False
-        if "use_flux" in kwargs:
-            self.use_flux = kwargs["use_flux"]
-        elif hasattr(self.config, "use_flux"):
-            self.use_flux = self.config.use_flux
-        self.flux_transpose_weight = False
-        if "flux_transpose_weight" in kwargs:
-            self.flux_transpose_weight = kwargs["flux_transpose_weight"]
-        elif hasattr(self.config, "flux_transpose_weight"):
-            self.flux_transpose_weight = self.config.flux_transpose_weight
-        self.previous_flux_params = (None,) * 5
-        self.fw_ag_gemm_op = None
-        self.bw_gemm_rs_op = None
-    return wrapper
-
-
-class ColumnParallelLinearPatch(torch.nn.Module):
+class FluxColumnParallelLinear(ColumnParallelLinear):
     """Linear layer with column parallelism.

     The linear layer is defined as Y = XA + b. A is parallelized along
     its second dimension as A = [A_1, ..., A_p].
+
+    Args:
+        input_size: first dimension of matrix A.
+        output_size: second dimension of matrix A.
+        bias: If true, add bias
+        gather_output: If true, call all-gather on output and make Y available to all GPUs,
+            otherwise, every GPU will have its output which is Y_i = XA_i
+        init_method: method to initialize weights. Note that bias is always set to zero.
+        stride: For the strided linear layers.
+        keep_master_weight_for_test: This was added for testing and should be set to False. It
+            returns the master weights used for initialization.
+        skip_bias_add: If True, do not add the bias term, instead return it to be added by the
+            caller. This enables performance optimations where bias can be fused with other
+            elementwise operations.
+        skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed
+            as a keyword argument `weight` during the forward pass. Note that this does not
+            affect bias, which will be allocated if bias is True. Defaults to False.
+        embedding_activation_buffer: This buffer holds the input activations of the final embedding
+            linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
+        grad_output_buffer: This buffer holds the gradient outputs of the final embedding linear
+            layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
+        is_expert: If True, the layer is treated as an MoE expert layer.
+        config: ModelParallelConfig object
+        tp_comm_buffer_name: Communication buffer name is not used in non-Transformer-Engine modules.
+        disable_grad_reduce: If True, reduction of output gradients across tensor-parallel ranks
+            will be disabled. Defaults to False. This feature is used by Lora Adapter in Nemo to
+            delay and fuse reduction along with other gradients for performance optimization.
     """

+    def __init__(
+        self,
+        input_size,
+        output_size,
+        *,
+        config: ModelParallelConfig,
+        init_method: Callable,
+        bias=True,
+        gather_output=False,
+        stride=1,
+        keep_master_weight_for_test=False,
+        skip_bias_add=False,
+        skip_weight_param_allocation: bool = False,
+        embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
+        grad_output_buffer: Optional[List[torch.Tensor]] = None,
+        is_expert: bool = False,
+        tp_comm_buffer_name: str = None,  # Not used
+        disable_grad_reduce: bool = False,
+    ):
+        super(FluxColumnParallelLinear, self).__init__(
+            input_size=input_size,
+            output_size=output_size,
+            config=config,
+            init_method=init_method,
+            bias=bias,
+            gather_output=gather_output,
+            stride=stride,
+            keep_master_weight_for_test=keep_master_weight_for_test,
+            skip_bias_add=skip_bias_add,
+            skip_weight_param_allocation=skip_weight_param_allocation,
+            embedding_activation_buffer=embedding_activation_buffer,
+            grad_output_buffer=grad_output_buffer,
+            is_expert=is_expert,
+            tp_comm_buffer_name=tp_comm_buffer_name,
+            disable_grad_reduce=disable_grad_reduce,
+        )
+
+        # flux params
+        self._forward_impl = ag_linear
+        self.flux_transpose_weight = getattr(self.config, "flux_transpose_weight", False)
+        self.previous_flux_params = (None,) * 5
+        self.fw_ag_gemm_op = None
+        self.bw_gemm_rs_op = None
+
     def forward(
         self,
         input_: torch.Tensor,
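The new __init__ reads its flux option with getattr and a default, replacing the removed wrapper's kwargs-then-config lookup. A minimal sketch with a stand-in config object (SimpleNamespace is illustrative, not a real ModelParallelConfig):

    from types import SimpleNamespace

    config = SimpleNamespace()  # no flux_transpose_weight attribute set
    print(getattr(config, "flux_transpose_weight", False))  # False, the default

    config = SimpleNamespace(flux_transpose_weight=True)
    print(getattr(config, "flux_transpose_weight", False))  # True, taken from the config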
@@ -867,14 +938,11 @@ class ColumnParallelLinearPatch(torch.nn.Module):
         ):
             self.embedding_activation_buffer.append(input_parallel)

-        # Matrix multiply.
-        if self.use_flux:
-            assert HAS_FLUX, "flux is NOT installed"
+        # flux kernels.
+        if self.sequence_parallel:
             sequence_len, batch_size, input_hidden_size = input_parallel.size()
             output_hidden_size = weight.size(0)
             world_size = get_tensor_model_parallel_world_size()
-            if self.sequence_parallel:
             current_flux_params = (
                 sequence_len,
                 batch_size,

@@ -913,32 +981,21 @@ class ColumnParallelLinearPatch(torch.nn.Module):
             self.previous_flux_params = current_flux_params
-            self._forward_impl = ag_linear
-        elif not weight.requires_grad:
-            self._forward_impl = linear_with_frozen_weight
-        else:
-            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce

         allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad
-        output_parallel = self._forward_impl(
-            input=input_parallel,
-            weight=weight,
-            bias=bias,
-            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
-            allreduce_dgrad=allreduce_dgrad,
-            sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
-            grad_output_buffer=self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None,
-            wgrad_deferral_limit=self.config.wgrad_deferral_limit if self.config.defer_embedding_wgrad_compute else None,
-            transpose_weight=self.flux_transpose_weight,
-            fw_ag_gemm_op=self.fw_ag_gemm_op,
-            bw_gemm_rs_op=self.bw_gemm_rs_op
-        )
+        forward_params = {
+            "input": input_parallel,
+            "weight": weight,
+            "bias": bias,
+            "gradient_accumulation_fusion": self.gradient_accumulation_fusion,
+            "allreduce_dgrad": allreduce_dgrad,
+            "sequence_parallel": False if self.explicit_expert_comm else self.sequence_parallel,
+            "grad_output_buffer": self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None,
+            "wgrad_deferral_limit": self.config.wgrad_deferral_limit if self.config.defer_embedding_wgrad_compute else None,
+        }
+        if self.use_flux:
+            forward_params.update({
+                "transpose_weight": self.flux_transpose_weight,
+                "fw_ag_gemm_op": self.fw_ag_gemm_op,
+                "bw_gemm_rs_op": self.bw_gemm_rs_op,
+            })
+        output_parallel = self._forward_impl(**forward_params)

         gather_output = self.gather_output
         # Use the runtime gather output if it's set explicitly.
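The forward now collects its keyword arguments in a dict, adds the flux-only ones conditionally, and splats the dict into whichever _forward_impl is active. A self-contained sketch of the pattern with made-up functions:

    def base_impl(*, x, scale):
        return x * scale

    def flux_impl(*, x, scale, transpose_weight):
        return x * scale, transpose_weight

    use_flux = True
    forward_params = {"x": 3.0, "scale": 2.0}
    if use_flux:
        forward_params.update({"transpose_weight": False})
    impl = flux_impl if use_flux else base_impl
    print(impl(**forward_params))  # (6.0, False) when use_flux is True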
@@ -955,38 +1012,79 @@ class ColumnParallelLinearPatch(torch.nn.Module):
         return output, output_bias


-def row_parallel_linear_init_wrapper(fn):
-    @wraps(fn)
-    def wrapper(self, *args, **kwargs):
-        fn(self, *args, **kwargs)
-        # flux params
-        self.use_flux = False
-        if "use_flux" in kwargs:
-            self.use_flux = kwargs["use_flux"]
-        elif hasattr(self.config, "use_flux"):
-            self.use_flux = self.config.use_flux
-        self.flux_transpose_weight = False
-        if "flux_transpose_weight" in kwargs:
-            self.flux_transpose_weight = kwargs["flux_transpose_weight"]
-        elif hasattr(self.config, "flux_transpose_weight"):
-            self.flux_transpose_weight = self.config.flux_transpose_weight
-        self.previous_flux_params = (None,) * 5
-        self.fw_gemm_rs_op = None
-        self.bw_ag_gemm_op = None
-    return wrapper
-
-
-class RowParallelLinearPatch(torch.nn.Module):
+class FluxRowParallelLinear(torch.nn.Module):
     """Linear layer with row parallelism.

     The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X
     along its second dimension. A = transpose([A_1 .. A_p]) X = [X_1, ..., X_p]
+
+    Args:
+        input_size: first dimension of matrix A.
+        output_size: second dimension of matrix A.
+        bias: If true, add bias. Note that bias is not parallelized.
+        input_is_parallel: If true, we assume that the input is already split across the GPUs
+            and we do not split again.
+        init_method: method to initialize weights. Note that bias is always set to zero.
+        stride: For the strided linear layers.
+        keep_master_weight_for_test: This was added for testing and should be set to False. It
+            returns the master weights used for initialization.
+        skip_bias_add: If True, do not add the bias term, instead return it to be added by the
+            caller. This enables performance optimations where bias can be fused with other
+            elementwise operations.
+        is_expert: If True, the layer is treated as an MoE expert layer
+        tp_comm_buffer_name: Communication buffer name. Not used in non-Transformer-Engine modules.
+        config: ModelParallelConfig object
     """

+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        *,
+        config: ModelParallelConfig,
+        init_method: Callable,
+        bias: bool,
+        input_is_parallel: bool,
+        skip_bias_add: bool,
+        stride: int = 1,
+        keep_master_weight_for_test: bool = False,
+        is_expert: bool = False,
+        tp_comm_buffer_name: str = None,  # Not used
+    ):
+        super(FluxRowParallelLinear, self).__init__(
+            input_size=input_size,
+            output_size=output_size,
+            config=config,
+            init_method=init_method,
+            bias=bias,
+            input_is_parallel=input_is_parallel,
+            skip_bias_add=skip_bias_add,
+            stride=stride,
+            keep_master_weight_for_test=keep_master_weight_for_test,
+            is_expert=is_expert,
+            tp_comm_buffer_name=tp_comm_buffer_name
+        )
+
+        # flux params
+        self._forward_impl = linear_rs
+        self.flux_transpose_weight = getattr(self.config, "flux_transpose_weight", False)
+        self.previous_flux_params = (None,) * 5
+        self.fw_gemm_rs_op = None
+        self.bw_ag_gemm_op = None
+
     def forward(self, input_):
         """Forward of RowParallelLinear
@@ -1011,14 +1109,14 @@ class RowParallelLinearPatch(torch.nn.Module):
         else:
             assert not self.sequence_parallel
             input_parallel = scatter_to_tensor_model_parallel_region(input_)

-        # Matrix multiply.
-        if self.use_flux:
-            assert HAS_FLUX, "flux is NOT installed"
+        # flux kernels
+        if self.sequence_parallel:
             sequence_len, batch_size, input_hidden_size = input_parallel.size()
             output_hidden_size = self.weight.size(0)
             world_size = get_tensor_model_parallel_world_size()
-            if self.sequence_parallel:
             current_flux_params = (
                 sequence_len,
                 batch_size,

@@ -1057,47 +1155,31 @@ class RowParallelLinearPatch(torch.nn.Module):
             self.previous_flux_params = current_flux_params
-            self._forward_impl = linear_rs
-        elif not self.weight.requires_grad:
-            self._forward_impl = linear_with_frozen_weight
-        else:
-            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce

-        output_parallel = self._forward_impl(
-            input=input_parallel,
-            weight=self.weight,
-            bias=self.bias if not self.skip_bias_add and self.sequence_parallel else None,
-            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
-            allreduce_dgrad=False,
-            sequence_parallel=False if explicit_expert_comm else self.sequence_parallel,
-            grad_output_buffer=None,
-            transpose_weight=self.flux_transpose_weight,
-            fw_gemm_rs_op=self.fw_gemm_rs_op,
-            bw_ag_gemm_op=self.bw_ag_gemm_op
-        )
-        if self.use_flux:
-            return output_parallel, None if not self.skip_bias_add else self.bias
-
-        # All-reduce across all the partitions.
+        allreduce_dgrad = False
+        forward_params = {
+            "input": input_parallel,
+            "weight": self.weight,
+            "bias": self.bias if self.use_flux or not self.skip_bias_add else None,
+            "gradient_accumulation_fusion": self.gradient_accumulation_fusion,
+            "allreduce_dgrad": allreduce_dgrad,
+            "sequence_parallel": False if not self.use_flux else self.sequence_parallel,
+            "grad_output_buffer": None,
+        }
+        if self.use_flux:
+            forward_params.update({
+                "transpose_weight": self.flux_transpose_weight,
+                "fw_gemm_rs_op": self.fw_gemm_rs_op,
+                "bw_ag_gemm_op": self.bw_ag_gemm_op,
+            })
+        output_parallel = self._forward_impl(**forward_params)

         if self.explicit_expert_comm:
             assert self.skip_bias_add
             output_ = output_parallel
         elif self.sequence_parallel:
-            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
+            output_ = output_parallel
         else:
             output_ = reduce_from_tensor_model_parallel_region(output_parallel)

         if not self.skip_bias_add:
-            output = (output_ + self.bias) if self.bias is not None else output_
             output_bias = None
+            if not self.sequence_parallel:
+                output = (output_ + self.bias) if self.bias is not None else output_
         else:
             output = output_
             output_bias = self.bias
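Reading of the last hunk (hedged; the numbers below are purely illustrative): with the flux path active, linear_rs performs the reduce-scatter inside the fused GEMM, so the sequence-parallel branch now passes output_parallel through rather than calling reduce_scatter_to_sequence_parallel_region a second time. The shape bookkeeping that implies:

    # No distributed setup here; illustrative sizes only.
    sequence_len, batch_size, hidden = 16, 2, 4096
    world_size = 8
    # After a reduce-scatter along the sequence dimension, each rank holds:
    local_output_shape = (sequence_len // world_size, batch_size, hidden)
    print(local_output_shape)  # (2, 2, 4096)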