evt_fugx1 / dcu_megatron / Commits / 595e428a

Commit 595e428a authored Apr 18, 2025 by dongcl

    use flux gemmrs

parent 0867fd90
Showing 7 changed files with 70 additions and 51 deletions (+70 / -51)
dcu_megatron/adaptor/megatron_adaptor.py                     +1  -1
dcu_megatron/core/models/gpt/gpt_model.py                    +2  -1
dcu_megatron/core/tensor_parallel/layers.py                  +64 -42
dcu_megatron/core/transformer/mtp/multi_token_predictor.py   +1  -0
dcu_megatron/core/transformer/transformer_config.py          +0  -3
dcu_megatron/training/arguments.py                           +1  -3
pretrain_gpt.py                                              +1  -1
dcu_megatron/adaptor/megatron_adaptor.py  (view file @ 595e428a)

@@ -187,7 +187,7 @@ class CoreAdaptation(MegatronAdaptationABC):
             apply_wrapper=True
         )
         # flux
-        if os.getenv("USE_FLUX_OVERLAP", 0):
+        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
            from ..core.tensor_parallel import (
                FluxColumnParallelLinear,
                FluxRowParallelLinear
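The only functional change here is the gating expression: `os.getenv` returns a string, so the old check was truthy for any non-empty value, including an explicit "0". A minimal sketch of the difference (illustrative, not project code):

```python
import os

# Simulate explicitly disabling the feature.
os.environ["USE_FLUX_OVERLAP"] = "0"

# Old check: os.getenv returns the string "0", which is truthy,
# so the flux import path would still be taken.
old_enabled = bool(os.getenv("USE_FLUX_OVERLAP", 0))         # True

# New check: casting to int makes "0" falsy, as intended.
new_enabled = bool(int(os.getenv("USE_FLUX_OVERLAP", "0")))  # False

print(old_enabled, new_enabled)
```

The same fix is applied to the `use_te` check in `pretrain_gpt.py` below.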
dcu_megatron/core/models/gpt/gpt_model.py  (view file @ 595e428a)

@@ -16,6 +16,7 @@ from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer.enums import ModelType
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_block import TransformerBlock
+from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
 from dcu_megatron.core.utils import tensor_slide
 from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor

@@ -40,7 +41,7 @@ def gpt_model_init_wrapper(fn):
         self.mtp_layers = torch.nn.ModuleList(
             [
                 MultiTokenPredictor(
-                    config,
+                    self.config,
                     self.mtp_spec.submodules,
                     vocab_size=self.vocab_size,
                     max_sequence_length=self.max_sequence_length,
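For context, this hunk sits inside `gpt_model_init_wrapper`, which wraps the model's `__init__` and attaches the MTP layers afterwards; switching from the local `config` argument to `self.config` reads whatever the wrapped constructor actually stored on the instance. A runnable sketch of that wrapper pattern (the toy model and sizes are hypothetical; only the `gpt_model_init_wrapper` name comes from the diff):

```python
import functools
import torch

def gpt_model_init_wrapper(fn):
    """Sketch: wrap __init__ and attach extra layers after it has run."""
    @functools.wraps(fn)
    def wrapper(self, config, *args, **kwargs):
        fn(self, config, *args, **kwargs)
        # Use self.config rather than the local `config` argument: after the
        # wrapped __init__ has run, the instance attribute is the
        # authoritative (possibly adjusted) copy of the configuration.
        self.mtp_layers = torch.nn.ModuleList(
            [torch.nn.Linear(self.config["hidden"], self.config["hidden"])]
        )
    return wrapper

class ToyGPT(torch.nn.Module):
    @gpt_model_init_wrapper
    def __init__(self, config):
        super().__init__()
        self.config = dict(config)  # the copy later code should rely on

model = ToyGPT({"hidden": 8})
print(model.mtp_layers)
```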
dcu_megatron/core/tensor_parallel/layers.py  (view file @ 595e428a)

 import os
+import socket
 import warnings
 from functools import wraps
 from typing import Callable, List, Optional
@@ -160,6 +161,19 @@ def vocab_parallel_embedding_forward(self, input_, weight=None):
     return output


+def get_tensor_model_parallel_node_size(group=None):
+    """Get the number of nodes spanned by the tensor-model-parallel group."""
+    if group is None:
+        group = get_tensor_model_parallel_group()
+    hostname = socket.gethostname()
+    hostnames = [None] * get_tensor_model_parallel_world_size()
+    torch.distributed.all_gather_object(hostnames, hostname, group=group)
+    num_nodes = len(set(hostnames))
+    return num_nodes
+
+
 class AGLinear(torch.autograd.Function):
     @staticmethod
     @custom_fwd
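The new `get_tensor_model_parallel_node_size` helper above replaces the node count of 1 that the flux kernels were previously constructed with: it all-gathers hostnames across the TP group and counts the distinct ones. A self-contained sketch of the same idea, run as a single-process gloo group (the TP group helpers from the diff are assumed, not redefined):

```python
import os
import socket

import torch.distributed as dist

# Stand-alone emulation of the hostname-based node count (single process).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29561")
dist.init_process_group("gloo", rank=0, world_size=1)

hostname = socket.gethostname()
hostnames = [None] * dist.get_world_size()
dist.all_gather_object(hostnames, hostname)

# Ranks on the same machine report the same hostname, so the number of
# distinct hostnames is the number of physical nodes in the group.
print(len(set(hostnames)))  # 1 for this single-process run

dist.destroy_process_group()
```

Unlike the commented-out `world_size // torch.cuda.device_count()` estimate, counting hostnames also stays correct when the group does not span whole nodes.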
@@ -196,7 +210,7 @@ class AGLinear(torch.autograd.Function):
         if fw_ag_gemm_op is None:
             fw_ag_gemm_op = flux.AGKernel(
                 get_tensor_model_parallel_group(),
-                1,  # torch.distributed.get_world_size() // torch.cuda.device_count(),
+                get_tensor_model_parallel_node_size(),
                 sequence_len * batch_size * world_size,
                 output_hidden_size,
                 input_hidden_size,
@@ -265,34 +279,31 @@ class AGLinear(torch.autograd.Function):
         if ctx.sequence_parallel:
             sequence_len, batch_size, _ = grad_output.size()

-            # if bw_gemm_rs_op is None:
-            #     input_hidden_size = weight.size(-1)
-            #     bw_gemm_rs_op = flux.GemmRS(
-            #         get_tensor_model_parallel_group(),
-            #         1, # world_size // torch.cuda.device_count(),
-            #         sequence_len * batch_size,
-            #         input_hidden_size,
-            #         input.dtype,
-            #         input.dtype,
-            #         transpose_weight=transpose_weight,
-            #         fuse_reduction=False
-            #     )
-            # grad_input = bw_gemm_rs_op.forward(
-            #     grad_output.view(sequence_len * batch_size, -1),
-            #     weight if transpose_weight else weight.t().contiguous(),
-            #     bias=None,
-            #     input_scale=None,
-            #     weight_scale=None,
-            #     output_scale=None,
-            #     fast_accum=False
-            # )
-            # torch.distributed.barrier()
-            # torch.cuda.current_stream().synchronize()
-            # grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
-            grad_input = grad_output.matmul(weight)
-            grad_input = _reduce_scatter_along_first_dim(grad_input)
+            if bw_gemm_rs_op is None:
+                input_hidden_size = weight.size(-1)
+                bw_gemm_rs_op = flux.GemmRS(
+                    get_tensor_model_parallel_group(),
+                    get_tensor_model_parallel_node_size(),
+                    sequence_len * batch_size,
+                    input_hidden_size,
+                    input.dtype,
+                    input.dtype,
+                    transpose_weight=transpose_weight,
+                    fuse_reduction=False
+                )
+            grad_input = bw_gemm_rs_op.forward(
+                grad_output.view(sequence_len * batch_size, -1),
+                weight if transpose_weight else weight.t().contiguous(),
+                bias=None,
+                input_scale=None,
+                weight_scale=None,
+                output_scale=None,
+                fast_accum=False
+            )
+            torch.cuda.current_stream().synchronize()
+            grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
         else:
             grad_input = grad_output.matmul(weight)
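The backward path above swaps the plain `matmul` + `_reduce_scatter_along_first_dim` fallback for the fused flux `GemmRS` call that was previously commented out; both yield a gradient of shape `(sequence_len // world_size, batch_size, hidden)`. A single-process sketch of the shape bookkeeping of the unfused path (made-up sizes, no flux dependency):

```python
import torch

tp = 4                                    # pretend tensor-parallel world size
seq, batch, hidden_out, hidden_in = 8, 2, 16, 32

grad_output = torch.randn(seq * batch, hidden_out)
weight = torch.randn(hidden_out, hidden_in)

# Unfused path: full GEMM first, then reduce-scatter along the first dim.
full = grad_output.matmul(weight)          # (seq * batch, hidden_in)

# A real reduce-scatter also sums the partial GEMMs from the other ranks; with
# a single pretend rank the sum is trivial, so rank r just keeps its row chunk.
rank = 0
chunk = full.chunk(tp, dim=0)[rank]
grad_input = chunk.view(seq // tp, batch, -1)
print(grad_input.shape)                    # torch.Size([2, 2, 32])
```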
@@ -514,7 +525,7 @@ class LinearRS(torch.autograd.Function):
         if fw_gemm_rs_op is None:
             fw_gemm_rs_op = flux.GemmRS(
                 get_tensor_model_parallel_group(),
-                1,  # world_size // torch.cuda.device_count(),
+                get_tensor_model_parallel_node_size(),
                 sequence_len * batch_size,
                 output_hidden_size,
                 input.dtype,
@@ -522,6 +533,7 @@ class LinearRS(torch.autograd.Function):
                 transpose_weight=transpose_weight,
                 fuse_reduction=False,
             )
             output = fw_gemm_rs_op.forward(
                 input.view(sequence_len * batch_size, -1),
                 weight.t().contiguous() if transpose_weight else weight,
@@ -531,12 +543,8 @@ class LinearRS(torch.autograd.Function):
                 output_scale=None,
                 fast_accum=False,
             )
-            torch.distributed.barrier()
             torch.cuda.current_stream().synchronize()
             output = output.view(sequence_len // world_size, batch_size, -1)
-            # output = torch.matmul(input, weight.t())
-            # output = _reduce_scatter_along_first_dim(output)
         else:
             output = torch.matmul(input, weight.t())
@@ -586,7 +594,7 @@ class LinearRS(torch.autograd.Function):
         if bw_ag_gemm_op is None:
             bw_ag_gemm_op = flux.AGKernel(
                 get_tensor_model_parallel_group(),
-                1,  # world_size // torch.cuda.device_count(),
+                get_tensor_model_parallel_node_size(),
                 sequence_len * batch_size * world_size,
                 input_hidden_size,
                 output_hidden_size,
@@ -605,10 +613,8 @@ class LinearRS(torch.autograd.Function):
                 output_scale=None,
                 fast_accum=False,
             )
-            torch.distributed.barrier()
             torch.cuda.current_stream().synchronize()
-            grad_input = grad_input.contiguous().view(sequence_len * world_size, batch_size, -1)
+            grad_input = grad_input.view(sequence_len * world_size, batch_size, -1)
         else:
             grad_input = grad_output.matmul(weight)
@@ -957,7 +963,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
        ):
            self.fw_ag_gemm_op = flux.AGKernel(
                get_tensor_model_parallel_group(),
-               1,  # torch.distributed.get_world_size() // torch.cuda.device_count(),
+               get_tensor_model_parallel_node_size(),
                sequence_len * batch_size * world_size,
                output_hidden_size,
                input_hidden_size,
@@ -970,7 +976,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
            self.bw_gemm_rs_op = flux.GemmRS(
                get_tensor_model_parallel_group(),
-               1,  # world_size // torch.cuda.device_count(),
+               get_tensor_model_parallel_node_size(),
                sequence_len * batch_size * world_size,
                input_hidden_size,
                input_parallel.dtype,
@@ -1011,6 +1017,14 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

+    def __repr__(self):
+        tp = self.output_size // self.output_size_per_partition
+        use_bias = self.bias is not None and self.bias is True
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size_per_partition}, bias={use_bias}, TP={tp})"
+        )
+

 class FluxRowParallelLinear(RowParallelLinear):
     """Linear layer with row parallelism.
@@ -1131,7 +1145,7 @@ class FluxRowParallelLinear(RowParallelLinear):
        ):
            self.fw_gemm_rs_op = flux.GemmRS(
                get_tensor_model_parallel_group(),
-               1,  # world_size // torch.cuda.device_count(),
+               get_tensor_model_parallel_node_size(),
                sequence_len * batch_size,
                output_hidden_size,
                input_parallel.dtype,
@@ -1142,7 +1156,7 @@ class FluxRowParallelLinear(RowParallelLinear):
            self.bw_ag_gemm_op = flux.AGKernel(
                get_tensor_model_parallel_group(),
-               1,  # torch.distributed.get_world_size() // torch.cuda.device_count(),
+               get_tensor_model_parallel_node_size(),
                sequence_len * batch_size,
                input_hidden_size,
                output_hidden_size,
@@ -1184,3 +1198,11 @@ class FluxRowParallelLinear(RowParallelLinear):
            output = output_
        output_bias = self.bias
        return output, output_bias
+
+    def __repr__(self):
+        tp = self.input_size // self.input_size_per_partition
+        use_bias = self.bias is not None and self.bias is True
+        return (
+            f"{type(self).__name__}(in_features={self.input_size_per_partition}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )
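The added `__repr__` methods report the layer's in/out feature sizes together with the TP factor recovered from the partition sizes. A toy illustration of the resulting string (the sizes are made up):

```python
# Illustrative only: what the added FluxRowParallelLinear.__repr__ renders for
# input_size=4096 split across TP=8 ranks (512 features per partition).
class _ReprDemo:
    input_size = 4096
    input_size_per_partition = 512
    output_size = 1024
    bias = None

    def __repr__(self):
        tp = self.input_size // self.input_size_per_partition
        use_bias = self.bias is not None and self.bias is True
        return (
            f"{type(self).__name__}(in_features={self.input_size_per_partition}, "
            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
        )

print(_ReprDemo())
# _ReprDemo(in_features=512, out_features=1024, bias=False, TP=8)
```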
dcu_megatron/core/transformer/mtp/multi_token_predictor.py  (view file @ 595e428a)

@@ -11,6 +11,7 @@ from megatron.core.models.common.embeddings.language_model_embedding import Lang
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer.module import MegatronModule
+from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
 from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
 from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module
dcu_megatron/core/transformer/transformer_config.py  (view file @ 595e428a)

@@ -26,9 +26,6 @@ class ExtraTransformerConfig:
     ##################
     # flux
     ##################
-    use_flux: bool = False
-    """If set, flux will be used in ColumnParallelLinear and RowParallelLinear"""
-
     flux_transpose_weight: bool = False
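With `use_flux` dropped from the config (and the matching `--use-flux` argument removed below), enabling flux is now controlled by the `USE_FLUX_OVERLAP` environment variable checked in `megatron_adaptor.py` and `pretrain_gpt.py`; only the transpose option stays a config field. A hedged sketch of the remaining dataclass field (the class name here is illustrative):

```python
from dataclasses import dataclass

@dataclass
class ExtraTransformerConfigSketch:
    # The only flux option left in the config after this commit.
    flux_transpose_weight: bool = False
    """Whether to transpose the weight when calling the flux kernels."""

cfg = ExtraTransformerConfigSketch()
print(cfg.flux_transpose_weight)  # False
```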
dcu_megatron/training/arguments.py  (view file @ 595e428a)

@@ -182,9 +182,7 @@ def _add_mtp_args(parser):
 def _add_flux_args(parser):
-    group = parser.add_argument_group(title='multi token prediction')
-    group.add_argument('--use-flux', action='store_true', default=False,
-                       help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear')
+    group = parser.add_argument_group(title='flux args')
     group.add_argument('--flux-transpose-weight', action='store_true', default=False,
                        help='Whether to transpose weight when using flux kernel')
     return parser
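The group-title fix and the removal of `--use-flux` leave `_add_flux_args` with a single flag. A quick standalone check of the remaining argument (plain argparse, outside Megatron's parser plumbing):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='flux args')
group.add_argument('--flux-transpose-weight', action='store_true', default=False,
                   help='Whether to transpose weight when using flux kernel')

args = parser.parse_args(['--flux-transpose-weight'])
print(args.flux_transpose_weight)  # True
```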
pretrain_gpt.py  (view file @ 595e428a)

@@ -61,7 +61,7 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
         Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
     """
     args = get_args()
-    use_te = args.transformer_impl == "transformer_engine" or bool(os.getenv("USE_FLUX_OVERLAP", 0))
+    use_te = args.transformer_impl == "transformer_engine" or bool(int(os.getenv("USE_FLUX_OVERLAP", "0")))

     if args.record_memory_history:
         torch.cuda.memory._record_memory_history(True,