evt_fugx1 / dcu_megatron

Commit 138b70a2
Authored Apr 17, 2025 by dongcl
Parent: 72aeb0f3

    fix flux bug
Showing 2 changed files, with 44 additions and 41 deletions:
    dcu_megatron/core/models/gpt/gpt_layer_specs.py  (+2, -2)
    dcu_megatron/core/tensor_parallel/layers.py      (+42, -39)
File: dcu_megatron/core/models/gpt/gpt_layer_specs.py @ 138b70a2
@@ -91,7 +91,7 @@ def get_gpt_layer_with_flux_spec(
                 ),
             ),
             self_attn_bda=get_bias_dropout_add,
-            pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
+            pre_mlp_layernorm=TENorm,
             mlp=mlp,
             mlp_bda=get_bias_dropout_add,
         ),
@@ -119,7 +119,7 @@ def get_gpt_layer_with_flux_spec(
                 ),
             ),
             self_attn_bda=get_bias_dropout_add,
-            pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
+            pre_mlp_layernorm=TENorm,
             mlp=mlp,
             mlp_bda=get_bias_dropout_add,
         ),
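Note on the gpt_layer_specs.py change: the flux spec now always uses TENorm as the pre-MLP layernorm, presumably because the flux MLP path uses linear layers that do not fuse the preceding layernorm the way TELayerNormColumnParallelLinear does in the standard TE spec, so leaving pre_mlp_layernorm as IdentityOp in the dense (num_experts=None) case would drop normalization before the MLP. A minimal sketch of the behavioral difference, using torch.nn.LayerNorm as a stand-in for TENorm (illustrative only, not the project's code):

    import torch

    class IdentityOp(torch.nn.Module):
        # stand-in for Megatron's IdentityOp: returns its input unchanged
        def forward(self, x, *args, **kwargs):
            return x

    hidden = torch.randn(8, 2, 1024)        # [sequence, batch, hidden]
    norm = torch.nn.LayerNorm(1024)         # stand-in for TENorm

    mlp_input_old = IdentityOp()(hidden)    # dense case before the fix: MLP sees unnormalized activations
    mlp_input_new = norm(hidden)            # after the fix: explicit pre-MLP layernorm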
File: dcu_megatron/core/tensor_parallel/layers.py @ 138b70a2
@@ -213,12 +213,12 @@ class AGLinear(torch.autograd.Function):
                 output_scale=None,
                 fast_accum=False
             )
             torch.distributed.barrier()
             torch.cuda.current_stream().synchronize()
             output = output.view(sequence_len * world_size, batch_size, -1)
         else:
             output = torch.matmul(input, weight.t())
             torch.cuda.current_stream().synchronize()
         return output

     @staticmethod
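For context, the sequence-parallel branch of AGLinear.forward shown above fuses an all-gather of the sequence-sharded input with the GEMM. A hedged, unfused reference of what that branch is expected to compute; tp_group and the function name are illustrative, not the project's helpers:

    import torch
    import torch.distributed as dist

    def ag_linear_forward_reference(input_shard, weight, tp_group):
        # input_shard: [seq_len // tp, batch, hidden], weight: [out_per_rank, hidden]
        world_size = dist.get_world_size(tp_group)
        gathered = [torch.empty_like(input_shard) for _ in range(world_size)]
        dist.all_gather(gathered, input_shard.contiguous(), group=tp_group)
        full_input = torch.cat(gathered, dim=0)       # [seq_len, batch, hidden]
        return torch.matmul(full_input, weight.t())   # [seq_len, batch, out_per_rank]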
@@ -260,31 +260,34 @@ class AGLinear(torch.autograd.Function):
         if ctx.sequence_parallel:
             sequence_len, batch_size, _ = grad_output.size()

-            if bw_gemm_rs_op is None:
-                input_hidden_size = weight.size(-1)
-                bw_gemm_rs_op = flux.GemmRS(
-                    get_tensor_model_parallel_group(),
-                    1,  # world_size // torch.cuda.device_count(),
-                    sequence_len * batch_size,
-                    input_hidden_size,
-                    input.dtype,
-                    input.dtype,
-                    transpose_weight=transpose_weight,
-                    fuse_reduction=False
-                )
-            grad_input = bw_gemm_rs_op.forward(
-                grad_output.view(sequence_len * batch_size, -1),
-                weight if transpose_weight else weight.t().contiguous(),
-                bias=None,
-                input_scale=None,
-                weight_scale=None,
-                output_scale=None,
-                fast_accum=False
-            )
-            torch.cuda.current_stream().synchronize()
-            grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
+            # if bw_gemm_rs_op is None:
+            # input_hidden_size = weight.size(-1)
+            # bw_gemm_rs_op = flux.GemmRS(
+            # get_tensor_model_parallel_group(),
+            # 1, # world_size // torch.cuda.device_count(),
+            # sequence_len * batch_size,
+            # input_hidden_size,
+            # input.dtype,
+            # input.dtype,
+            # transpose_weight=transpose_weight,
+            # fuse_reduction=False
+            # )
+            # grad_input = bw_gemm_rs_op.forward(
+            # grad_output.view(sequence_len * batch_size, -1),
+            # weight if transpose_weight else weight.t().contiguous(),
+            # bias=None,
+            # input_scale=None,
+            # weight_scale=None,
+            # output_scale=None,
+            # fast_accum=False
+            # )
+            # torch.distributed.barrier()
+            # torch.cuda.current_stream().synchronize()
+            # grad_input = grad_input.view(sequence_len // world_size, batch_size, -1)
+            grad_input = grad_output.matmul(weight)
+            grad_input = _reduce_scatter_along_first_dim(grad_input)
         else:
             grad_input = grad_output.matmul(weight)
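The backward fix above replaces the flux.GemmRS call (now commented out) with a plain GEMM followed by a reduce-scatter along the sequence dimension, which is the unfused equivalent. A hedged sketch of that fallback; reduce_scatter_first_dim is an illustrative stand-in for _reduce_scatter_along_first_dim:

    import torch
    import torch.distributed as dist

    def reduce_scatter_first_dim(t, tp_group):
        # scatter the summed tensor along dim 0 across the tensor-parallel group
        world_size = dist.get_world_size(tp_group)
        out = torch.empty((t.size(0) // world_size, *t.shape[1:]), dtype=t.dtype, device=t.device)
        dist.reduce_scatter_tensor(out, t.contiguous(), group=tp_group)
        return out

    def ag_linear_backward_reference(grad_output, weight, tp_group):
        # grad_output: [seq_len, batch, out_per_rank], weight: [out_per_rank, hidden]
        grad_input = grad_output.matmul(weight)                  # [seq_len, batch, hidden]
        return reduce_scatter_first_dim(grad_input, tp_group)    # [seq_len // tp, batch, hidden]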
@@ -496,8 +499,6 @@ class LinearRS(torch.autograd.Function):
         world_size = get_tensor_model_parallel_world_size()
         sequence_len, batch_size, _ = input.size()
-        # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
-        input = input.view(sequence_len * batch_size, -1)

         output_hidden_size = weight.size(0)

         if sequence_parallel:
@@ -513,7 +514,7 @@ class LinearRS(torch.autograd.Function):
                 fuse_reduction=False,
             )
             output = fw_gemm_rs_op.forward(
-                input,
+                input.view(sequence_len * batch_size, -1),
                 weight.t().contiguous() if transpose_weight else weight,
                 bias=bias,
                 input_scale=None,
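The two LinearRS hunks above stop flattening input in place and instead pass input.view(sequence_len * batch_size, -1) directly to the flux GEMM call, presumably so the original 3D tensor remains available to the rest of the function. The reshape is pure shape bookkeeping; a small sanity sketch in plain torch with illustrative shapes:

    import torch

    s, b, h, out = 8, 2, 16, 32
    x = torch.randn(s, b, h)
    w = torch.randn(out, h)

    y_3d = torch.matmul(x, w.t())                                   # [s, b, out]
    y_2d = torch.matmul(x.view(s * b, -1), w.t()).view(s, b, -1)    # flatten, GEMM, reshape back
    assert torch.allclose(y_3d, y_2d)                               # same result; view() copies no data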
@@ -521,12 +522,16 @@ class LinearRS(torch.autograd.Function):
                 output_scale=None,
                 fast_accum=False,
             )
             torch.distributed.barrier()
             torch.cuda.current_stream().synchronize()
             output = output.view(sequence_len // world_size, batch_size, -1)
+            # output = torch.matmul(input, weight.t())
+            # output = _reduce_scatter_along_first_dim(output)
         else:
             output = torch.matmul(input, weight.t())
+            output = _reduce(output)
+            # torch.cuda.current_stream().synchronize()
         return output

     @staticmethod
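In the non-sequence-parallel branch above, output = _reduce(output) all-reduces the partial GEMM result across the tensor-parallel group, which a row-parallel linear needs in order to produce the full output on every rank. A hedged reference of that branch, with illustrative names in place of Megatron's helpers:

    import torch
    import torch.distributed as dist

    def linear_rs_forward_no_sp_reference(input, weight, tp_group):
        # input: [seq, batch, hidden_per_rank], weight: [out, hidden_per_rank]
        partial = torch.matmul(input, weight.t())   # each rank computes a partial sum over its hidden shard
        dist.all_reduce(partial, group=tp_group)    # sum the partials so every rank holds the full output
        return partial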
@@ -785,10 +790,9 @@ def column_parallel_linear_init_wrapper(fn):
         elif hasattr(self.config, "flux_transpose_weight"):
             self.flux_transpose_weight = self.config.flux_transpose_weight

-        if self.sequence_parallel:
-            self.previous_flux_params = (None,) * 5
-            self.fw_ag_gemm_op = None
-            self.bw_gemm_rs_op = None
+        self.previous_flux_params = (None,) * 5
+        self.fw_ag_gemm_op = None
+        self.bw_gemm_rs_op = None

     return wrapper
@@ -969,10 +973,9 @@ def row_parallel_linear_init_wrapper(fn):
         elif hasattr(self.config, "flux_transpose_weight"):
             self.flux_transpose_weight = self.config.flux_transpose_weight

-        if self.sequence_parallel:
-            self.previous_flux_params = (None,) * 5
-            self.fw_gemm_rs_op = None
-            self.bw_ag_gemm_op = None
+        self.previous_flux_params = (None,) * 5
+        self.fw_gemm_rs_op = None
+        self.bw_ag_gemm_op = None

     return wrapper
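The two init-wrapper hunks above (column- and row-parallel) initialize the flux bookkeeping attributes unconditionally instead of only when sequence_parallel is set, so the attributes exist whenever forward/backward later checks or rebuilds its cached flux operators. A hedged sketch of the wrapper pattern; FakeLinear is a made-up stand-in for the wrapped Megatron layer class:

    def init_wrapper(fn):
        # decorate __init__: run the original, then attach flux op caches unconditionally
        def wrapper(self, *args, **kwargs):
            fn(self, *args, **kwargs)
            self.previous_flux_params = (None,) * 5   # params the cached ops were built for
            self.fw_ag_gemm_op = None                 # lazily created flux forward op
            self.bw_gemm_rs_op = None                 # lazily created flux backward op
        return wrapper

    class FakeLinear:
        def __init__(self, in_features, out_features):
            self.in_features, self.out_features = in_features, out_features

    FakeLinear.__init__ = init_wrapper(FakeLinear.__init__)
    layer = FakeLinear(1024, 4096)
    assert layer.fw_ag_gemm_op is None and layer.previous_flux_params == (None,) * 5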