evt_fugx1 / dcu_megatron / Commits

Commit f25c421e
Authored Apr 14, 2025 by dongcl

fix flux bug

Parent: b6eb1484
Showing 2 changed files with 97 additions and 83 deletions (+97, -83)
dcu_megatron/adaptor/megatron_adaptor.py       +8  -9
dcu_megatron/core/tensor_parallel/layers.py   +89 -74
dcu_megatron/adaptor/megatron_adaptor.py
@@ -143,9 +143,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
-                                    apply_wrapper=True)
+        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
+        #                             apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
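For context: torch.compile called without a target function returns a decorator, which is why it can be passed as the wrapper argument here. A minimal sketch, assuming MegatronAdaptation.register(name, wrapper, apply_wrapper=True) applies the wrapper callable to the function named by name (permute_tokens below is a hypothetical stand-in, not the real MoE utility):

    import torch

    # Real PyTorch behaviour: torch.compile(...) with no function returns a decorator.
    compile_wrapper = torch.compile(options={"triton.cudagraphs": True,
                                             "triton.cudagraph_trees": False})

    def permute_tokens(tokens, order):           # hypothetical stand-in for the wrapped utility
        return tokens[order]

    compiled_permute = compile_wrapper(permute_tokens)   # what apply_wrapper=True would install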
@@ -197,18 +197,17 @@ class CoreAdaptation(MegatronAdaptationABC):
             HAS_FLUX = False

         if HAS_FLUX:
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
-                                        parallel_linear_init_wrapper,
-                                        apply_wrapper=True)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
-                                        ColumnParallelLinearPatch.forward)
+            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
+            #                             parallel_linear_init_wrapper,
+            #                             apply_wrapper=True)
+            # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
+            #                             ColumnParallelLinearPatch.forward)
             MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
                                         parallel_linear_init_wrapper,
                                         apply_wrapper=True)
             MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
                                         RowParallelLinearPatch.forward)

     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
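Registering RowParallelLinearPatch.forward against the upstream class path replaces the method on Megatron's RowParallelLinear when the adaptor is applied. A minimal sketch of that kind of method patching (the class bodies are illustrative, not the adaptor's real mechanism):

    class RowParallelLinear:                         # stands in for the upstream Megatron class
        def forward(self, x):
            return "original forward"

    class RowParallelLinearPatch:
        def forward(self, x):                        # flux-aware replacement
            return "patched forward"

    # What a register("...RowParallelLinear.forward", RowParallelLinearPatch.forward)
    # call ultimately amounts to:
    RowParallelLinear.forward = RowParallelLinearPatch.forward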
dcu_megatron/core/tensor_parallel/layers.py
@@ -13,8 +13,11 @@ import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter

 from megatron.training import print_rank_0
 from megatron.core.model_parallel_config import ModelParallelConfig
 from megatron.core.parallel_state import (
     get_global_memory_buffer,
     get_tensor_model_parallel_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
@@ -31,12 +34,15 @@ from megatron.core.tensor_parallel.mappings import (
     copy_to_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
     reduce_scatter_to_sequence_parallel_region,
     _reduce_scatter_along_first_dim,
     _gather_along_first_dim,
 )
 from megatron.core.tensor_parallel.utils import VocabUtility
 from megatron.core.tensor_parallel.mappings import _reduce
 from megatron.core.tensor_parallel.layers import (
     custom_fwd,
     custom_bwd,
     dist_all_gather_func,
     linear_with_frozen_weight,
     linear_with_grad_accumulation_and_async_allreduce
 )
@@ -176,26 +182,24 @@ class AGLinear(torch.autograd.Function):
         ctx.grad_output_buffer = grad_output_buffer
         ctx.transpose_weight = transpose_weight

-        sequence_len = input.size(0)
-        input = input.view(input.shape[0] * input.shape[1], input.shape[2])
-        M, K = list(input.size())
-        N = weight.size(0)
-        M = M * get_tensor_model_parallel_world_size()
+        sequence_len, batch_size, input_hidden_size = input.size()
+        # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
+        input = input.view(sequence_len * batch_size, input_hidden_size)
+        output_hidden_size = weight.size(0)

         if transpose_weight:
             weight = weight.t().contiguous()

+        if sequence_parallel:
+            sequence_len = sequence_len * get_tensor_model_parallel_world_size()

         ag_gemm_kernel = flux.AGKernel(
             get_tensor_model_parallel_group(),
-            get_tensor_model_parallel_world_size() // torch.cuda.device_count(),
-            M,
-            N,
-            K,
+            1,  # torch.distributed.get_world_size() // torch.cuda.device_count(),
+            sequence_len * batch_size,
+            output_hidden_size,
+            input_hidden_size,
             input.dtype,
             output_dtype=input.dtype,
             transpose_weight=transpose_weight,
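The rewritten forward flattens the [sequence, batch, hidden] activation into a 2-D matrix before handing its dimensions to the flux kernel, and under sequence parallelism the kernel sees the full sequence rather than the local shard. A small self-contained sketch of that shape bookkeeping (plain PyTorch, no flux or distributed setup; tp_world_size is a stand-in value):

    import torch

    sequence_len, batch_size, input_hidden_size = 4, 2, 8
    tp_world_size = 2                          # stand-in for get_tensor_model_parallel_world_size()
    x = torch.randn(sequence_len, batch_size, input_hidden_size)

    # flatten [s, b, h] -> [s*b, h] exactly as the patched forward does
    x2d = x.view(sequence_len * batch_size, input_hidden_size)

    # with sequence parallelism the GEMM's row count is (s * world_size) * b rather than s * b
    full_rows = sequence_len * tp_world_size * batch_size
    print(x2d.shape, full_rows)                # torch.Size([8, 8]) 16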
@@ -206,13 +210,13 @@ class AGLinear(torch.autograd.Function):
                 input,
                 weight,
                 bias=bias,
-                input_scale=input_scale,
-                weight_scale=weight_scale,
+                input_scale=None,
+                weight_scale=None,
                 output_scale=None,
                 fast_accum=False)
         else:
-            output_buf = torch.empty([M, N], dtype=input.dtype, device=input.device)
+            output_buf = torch.empty([sequence_len * batch_size, output_hidden_size],
+                                     dtype=input.dtype, device=input.device)
             gemm_only_op = flux.GemmOnly(
                 input_dtype=input.dtype,
                 output_dtype=input.dtype,
@@ -231,7 +235,7 @@ class AGLinear(torch.autograd.Function):
             )
         torch.cuda.current_stream().synchronize()

-        output = output.view(sequence_len, input.size(0) // sequence_len, -1)
+        output = output.view(sequence_len, batch_size, -1)
         return output
@@ -272,20 +276,17 @@ class AGLinear(torch.autograd.Function):
         if ctx.sequence_parallel:
+            sequence_len, batch_size, output_hidden_size = grad_output.size()
+            input_hidden_size = weight.size(-1)
+            # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
+            grad_output = grad_output.view(sequence_len * batch_size, output_hidden_size)
             if not transpose_weight:
                 weight = weight.t().contiguous()
             gemm_rs_op = flux.GemmRS(
                 get_tensor_model_parallel_group(),
-                world_size // torch.cuda.device_count(),
+                1,  # world_size // torch.cuda.device_count(),
                 sequence_len * batch_size,
                 output_hidden_size,
                 input_hidden_size,
                 input.dtype,
                 input.dtype,
                 transpose_weight=transpose_weight,
@@ -293,7 +294,7 @@ class AGLinear(torch.autograd.Function):
             )
             grad_input = gemm_rs_op.forward(
                 grad_output,
-                weight,
+                weight if transpose_weight else weight.t().contiguous(),
                 bias=None,
                 input_scale=None,
                 weight_scale=None,
@@ -302,13 +303,16 @@ class AGLinear(torch.autograd.Function):
             )
             torch.cuda.current_stream().synchronize()
-            grad_input = grad_input.view(sequence_len // get_tensor_model_parallel_group(), batch_size, -1)
+            grad_input = grad_input.view(sequence_len // get_tensor_model_parallel_world_size(), batch_size, -1)
         else:
             grad_input = grad_output.matmul(weight)

         if ctx.sequence_parallel and wgrad_compute:
             handle.wait()

+        if ctx.sequence_parallel:
+            grad_output = grad_output.view(sequence_len, batch_size, output_hidden_size)
+
         if wgrad_compute:
             grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
                 grad_output, total_input
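In the sequence-parallel backward path the GemmRS kernel ends with each rank holding only its slice of the sequence dimension, so the corrected view divides sequence_len by the world size (an integer) instead of by the process group object. A small worked sketch of that arithmetic (plain tensors on one process; tp_world_size is a stand-in, and the slicing below only mimics the local output's shape):

    import torch

    sequence_len, batch_size, hidden = 8, 2, 16
    tp_world_size = 4                      # stand-in for get_tensor_model_parallel_world_size()

    # full [s*b, h] gradient before the reduce-scatter
    grad_full = torch.randn(sequence_len * batch_size, hidden)

    # after reduce-scatter along the first dim each rank keeps 1/world_size of the rows,
    # which is why the fixed code views with sequence_len // world_size
    local_rows = (sequence_len // tp_world_size) * batch_size
    grad_local = grad_full[:local_rows]
    grad_local = grad_local.view(sequence_len // tp_world_size, batch_size, -1)
    print(grad_local.shape)                # torch.Size([2, 2, 16])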
@@ -503,25 +507,17 @@ class LinearRS(torch.autograd.Function):
         world_size = get_tensor_model_parallel_world_size()
-        input_dim = input.dim()
-        sequence_len = input.size(0)
-        input = input.view(input.shape[0] * input.shape[1], input.shape[2])
-        M = input.size(0)
-        N = weight.size(0)
+        sequence_len, batch_size, _ = input.size()
+        # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
+        input = input.view(sequence_len * batch_size, -1)
+        output_hidden_size = weight.size(0)

         if sequence_parallel:
             if transpose_weight:
                 weight = weight.t().contiguous()
             gemm_rs_op = flux.GemmRS(
                 get_tensor_model_parallel_group(),
-                world_size // torch.cuda.device_count(),
-                M,
-                N,
+                1,  # world_size // torch.cuda.device_count(),
+                sequence_len * batch_size,
+                output_hidden_size,
                 input.dtype,
                 input.dtype,
                 transpose_weight=transpose_weight,
@@ -529,15 +525,23 @@ class LinearRS(torch.autograd.Function):
             )
             output = gemm_rs_op.forward(
                 input,
-                weight,
+                weight.t().contiguous() if transpose_weight else weight,
                 bias=bias,
                 input_scale=None,
                 weight_scale=None,
                 output_scale=None,
                 fast_accum=False,
             )
+            output = output.view(sequence_len // world_size, batch_size, -1)
+            # output = torch.matmul(input, weight.t())
+            # return _reduce_scatter_along_first_dim(output)
         else:
-            output = torch.empty([M, N], dtype=input.dtype, device=input.device)
+            output_buf = torch.empty(
+                [sequence_len * batch_size, output_hidden_size],
+                dtype=input.dtype, device=input.device, requires_grad=False
+            )
             gemm_only_op = flux.GemmOnly(
                 input_dtype=input.dtype,
                 output_dtype=input.dtype,
@@ -546,21 +550,18 @@ class LinearRS(torch.autograd.Function):
             )
             output = gemm_only_op.forward(
                 input,
-                weight,
+                weight.t().contiguous() if transpose_weight else weight,
                 bias=bias,
-                output_buf=output,
+                output_buf=output_buf,
                 input_scale=None,
                 weight_scale=None,
                 output_scale=None,
                 fast_accum=False,
             )
+            output = output.view(sequence_len, batch_size, -1)
+            output = _reduce(output)
-        torch.cuda.current_stream().synchronize()
-        output = output.view(sequence_len, input.size(0) // sequence_len, -1)
-        if not sequence_parallel:
-            _reduce(output)
+        # torch.cuda.current_stream().synchronize()
         return output

     @staticmethod
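In the non-sequence-parallel branch the patched forward runs a local GEMM into a pre-allocated buffer and then all-reduces the partial result across the tensor-parallel group via _reduce. A plain-PyTorch sketch of the equivalent math on a single process (no flux, no process group; shapes are illustrative):

    import torch

    sequence_len, batch_size, input_hidden, output_hidden = 4, 2, 16, 32
    x = torch.randn(sequence_len, batch_size, input_hidden)
    weight = torch.randn(output_hidden, input_hidden)    # Megatron stores weight as [out, in]

    # local GEMM on the flattened activation, as flux.GemmOnly would do into output_buf
    out = torch.matmul(x.view(sequence_len * batch_size, input_hidden), weight.t())
    out = out.view(sequence_len, batch_size, -1)

    # in the real code _reduce() all-reduces partial results across tensor-parallel ranks;
    # with a single rank it is the identity
    print(out.shape)                                      # torch.Size([4, 2, 32])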
@@ -579,37 +580,45 @@ class LinearRS(torch.autograd.Function):
             grad_output_buffer.append(grad_output)
             wgrad_compute = False

-        if ctx.sequence_parallel:
-            world_size = get_tensor_model_parallel_world_size()
+        if wgrad_compute:
+            if ctx.sequence_parallel:
                 dim_size = list(grad_output.size())
                 dim_size[0] = dim_size[0] * world_size
+                sequence_len, batch_size, _ = grad_output.size()
+                grad_output = grad_output.view(sequence_len * batch_size, -1)
                 all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, grad_output.dtype, "mpu")
                 handle = dist_all_gather_func(
                     all_gather_buffer, grad_output, group=get_tensor_model_parallel_group(), async_op=True
                 )
-            M, K = list(grad_output.size())
-            M = M * world_size
-            N = weight.size(-1)
                 # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
                 # gather is scheduled before the input gradient computation
                 total_grad_output = all_gather_buffer
             else:
                 total_grad_output = grad_output

         if not transpose_weight:
             weight = weight.t().contiguous()

         if ctx.sequence_parallel:
             world_size = get_tensor_model_parallel_world_size()
-            grad_input = torch.empty([M, N], dtype=input.dtype, device=input.device)
+            sequence_len, batch_size, output_hidden_size = grad_output.size()
+            input_hidden_size = weight.size(-1)
             ag_kernel = flux.AGKernel(
                 get_tensor_model_parallel_group(),
-                world_size // torch.cuda.device_count(),
-                M,
-                N,
-                K,
-                input.dtype,
+                1,  # world_size // torch.cuda.device_count(),
+                sequence_len * batch_size * world_size,
+                input_hidden_size,
+                output_hidden_size,
+                grad_output.dtype,
                 output_dtype=input.dtype,
                 transpose_weight=transpose_weight,
                 local_copy=False,
                 ring_mode=flux.AgRingMode.Auto,
             )
-            output = ag_kernel.forward(
-                grad_output,
-                weight,
+            grad_input = ag_kernel.forward(
+                grad_output.view(sequence_len * batch_size, -1),
+                weight if transpose_weight else weight.t().contiguous(),
                 bias=None,
                 input_scale=None,
                 weight_scale=None,
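For the weight gradient, the sequence-parallel path all-gathers grad_output across ranks before the wgrad GEMM, because every rank needs the gradient for the full sequence. A plain-PyTorch sketch of why the gathered tensor is required (single process; world_size is a stand-in and torch.cat stands in for dist_all_gather_func):

    import torch

    world_size, seq_per_rank, batch, in_h, out_h = 2, 4, 2, 16, 32
    total_input = torch.randn(world_size * seq_per_rank * batch, in_h)    # gathered activations

    # under sequence parallelism each rank only holds its slice of grad_output
    grad_shards = [torch.randn(seq_per_rank * batch, out_h) for _ in range(world_size)]

    # dist_all_gather_func collects the shards into one buffer; torch.cat plays that role here
    total_grad_output = torch.cat(grad_shards, dim=0)

    # the weight gradient needs the full sequence on every rank
    grad_weight = total_grad_output.t().matmul(total_input)              # [out_h, in_h]
    print(grad_weight.shape)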
@@ -617,24 +626,29 @@ class LinearRS(torch.autograd.Function):
                 fast_accum=False,
             )
+            torch.distributed.barrier()
             torch.cuda.current_stream().synchronize()
+            grad_input = grad_input.contiguous().view(sequence_len * world_size, batch_size, -1)
         else:
             grad_input = grad_output.matmul(weight)

         if ctx.sequence_parallel and wgrad_compute:
             handle.wait()

         if wgrad_compute:
-            grad_output, total_input = prepare_input_tensors_for_wgrad_compute(grad_output, input)
+            total_grad_output, total_input = prepare_input_tensors_for_wgrad_compute(total_grad_output, input)

         if ctx.gradient_accumulation_fusion:
             if wgrad_compute:
                 if weight.main_grad.dtype == torch.float32:
                     fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
-                        total_input, grad_output, weight.main_grad
+                        total_input, total_grad_output, weight.main_grad
                     )
                 elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                     fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
-                        total_input, grad_output, weight.main_grad
+                        total_input, total_grad_output, weight.main_grad
                     )
                 else:
                     raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
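When gradient_accumulation_fusion is enabled, the fused kernel accumulates the weight gradient directly into weight.main_grad. A hedged, unfused stand-in for what that accumulation amounts to (plain matmul instead of apex's fused_weight_gradient_mlp_cuda; shapes are illustrative):

    import torch

    out_h, in_h, rows = 32, 16, 64
    weight_main_grad = torch.zeros(out_h, in_h, dtype=torch.float32)   # running fp32 accumulator
    total_input = torch.randn(rows, in_h)
    total_grad_output = torch.randn(rows, out_h)

    # roughly what wgrad_gemm_accum_fp32(total_input, total_grad_output, weight.main_grad)
    # performs, expressed without the fused kernel
    weight_main_grad += total_grad_output.t().matmul(total_input)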
@@ -662,8 +676,8 @@ class LinearRS(torch.autograd.Function):
             else:
                 grad_weight = None
         else:
-            grad_weight = grad_output.t().matmul(total_input)
-            grad_bias = grad_output.sum(dim=0) if use_bias else None
+            grad_weight = total_grad_output.t().matmul(total_input)
+            grad_bias = total_grad_output.sum(dim=0) if use_bias else None

         return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
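The unfused branch computes the same quantities explicitly: the weight gradient is the gathered gradient transposed against the gathered input, and the bias gradient is the column sum. A short worked check of those two formulas (values are illustrative):

    import torch

    rows, in_h, out_h = 6, 4, 5
    total_input = torch.randn(rows, in_h)
    total_grad_output = torch.randn(rows, out_h)

    grad_weight = total_grad_output.t().matmul(total_input)   # [out_h, in_h], as in the fixed code
    grad_bias = total_grad_output.sum(dim=0)                   # [out_h]
    print(grad_weight.shape, grad_bias.shape)                  # torch.Size([5, 4]) torch.Size([5])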
@@ -952,19 +966,20 @@ class RowParallelLinearPatch(torch.nn.Module):
         forward_params = {
             "input": input_parallel,
             "weight": self.weight,
-            "bias": None if not self.use_flux or self.skip_bias_add else self.bias,
+            "bias": self.bias if self.use_flux or not self.skip_bias_add else None,
             "gradient_accumulation_fusion": self.gradient_accumulation_fusion,
             "allreduce_dgrad": allreduce_dgrad,
             "sequence_parallel": False if not self.use_flux else self.sequence_parallel,
-            "grad_output_buffer": False,
+            "grad_output_buffer": None,
         }
         if self.use_flux:
             forward_params.update({"transpose_weight": self.flux_transpose_weight})

         output_parallel = self._forward_impl(**forward_params)

         if self.use_flux:
-            return output_parallel, None if skip_bias_add else self.bias
+            return output_parallel, None if not self.skip_bias_add else self.bias

         # All-reduce across all the partitions.
         if self.explicit_expert_comm:
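The corrected bias handling passes the bias into the kernel whenever flux is in use or the bias is meant to be folded into the output, and only returns it separately when skip_bias_add is set. A small sketch of that selection logic (pick_bias and its arguments are illustrative locals, not the module's real API):

    import torch

    def pick_bias(use_flux, skip_bias_add, bias):
        # bias handed to the GEMM, mirroring the fixed forward_params["bias"] expression
        kernel_bias = bias if use_flux or not skip_bias_add else None
        # bias returned next to the output, mirroring the fixed return statement
        returned_bias = bias if skip_bias_add else None
        return kernel_bias, returned_bias

    bias = torch.zeros(8)
    kernel_bias, returned_bias = pick_bias(use_flux=False, skip_bias_add=False, bias=bias)
    # here the bias is folded into the GEMM and nothing is returned separately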