Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
evt_fugx1
dcu_megatron
Commits
8ec8fb6b
Commit
8ec8fb6b
authored
Apr 15, 2025
by
dongcl
Browse files
move flux kernels outside
parent
5da71bf3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
95 additions
and
49 deletions
+95
-49
dcu_megatron/core/tensor_parallel/layers.py
dcu_megatron/core/tensor_parallel/layers.py
+95
-49
No files found.
dcu_megatron/core/tensor_parallel/layers.py
View file @
8ec8fb6b
...
@@ -479,6 +479,8 @@ class LinearRS(torch.autograd.Function):
...
@@ -479,6 +479,8 @@ class LinearRS(torch.autograd.Function):
grad_output_buffer
,
grad_output_buffer
,
wgrad_deferral_limit
,
wgrad_deferral_limit
,
transpose_weight
=
False
,
transpose_weight
=
False
,
fw_gemm_rs_op
=
None
,
bw_ag_gemm_op
=
None
):
):
"""Forward."""
"""Forward."""
ctx
.
save_for_backward
(
input
,
weight
)
ctx
.
save_for_backward
(
input
,
weight
)
...
@@ -489,6 +491,7 @@ class LinearRS(torch.autograd.Function):
...
@@ -489,6 +491,7 @@ class LinearRS(torch.autograd.Function):
ctx
.
wgrad_deferral_limit
=
wgrad_deferral_limit
ctx
.
wgrad_deferral_limit
=
wgrad_deferral_limit
ctx
.
grad_output_buffer
=
grad_output_buffer
ctx
.
grad_output_buffer
=
grad_output_buffer
ctx
.
transpose_weight
=
transpose_weight
ctx
.
transpose_weight
=
transpose_weight
ctx
.
bw_ag_gemm_op
=
bw_ag_gemm_op
world_size
=
get_tensor_model_parallel_world_size
()
world_size
=
get_tensor_model_parallel_world_size
()
...
@@ -498,7 +501,8 @@ class LinearRS(torch.autograd.Function):
...
@@ -498,7 +501,8 @@ class LinearRS(torch.autograd.Function):
output_hidden_size
=
weight
.
size
(
0
)
output_hidden_size
=
weight
.
size
(
0
)
if
sequence_parallel
:
if
sequence_parallel
:
gemm_rs_op
=
flux
.
GemmRS
(
if
fw_gemm_rs_op
is
None
:
fw_gemm_rs_op
=
flux
.
GemmRS
(
get_tensor_model_parallel_group
(),
get_tensor_model_parallel_group
(),
1
,
#world_size // torch.cuda.device_count(),
1
,
#world_size // torch.cuda.device_count(),
sequence_len
*
batch_size
,
sequence_len
*
batch_size
,
...
@@ -508,7 +512,7 @@ class LinearRS(torch.autograd.Function):
...
@@ -508,7 +512,7 @@ class LinearRS(torch.autograd.Function):
transpose_weight
=
transpose_weight
,
transpose_weight
=
transpose_weight
,
fuse_reduction
=
False
,
fuse_reduction
=
False
,
)
)
output
=
gemm_rs_op
.
forward
(
output
=
fw_
gemm_rs_op
.
forward
(
input
,
input
,
weight
.
t
().
contiguous
()
if
transpose_weight
else
weight
,
weight
.
t
().
contiguous
()
if
transpose_weight
else
weight
,
bias
=
bias
,
bias
=
bias
,
...
@@ -519,29 +523,7 @@ class LinearRS(torch.autograd.Function):
...
@@ -519,29 +523,7 @@ class LinearRS(torch.autograd.Function):
)
)
output
=
output
.
view
(
sequence_len
//
world_size
,
batch_size
,
-
1
)
output
=
output
.
view
(
sequence_len
//
world_size
,
batch_size
,
-
1
)
else
:
else
:
output_buf
=
torch
.
empty
(
output
=
torch
.
matmul
(
input
,
weight
.
t
())
[
sequence_len
*
batch_size
,
output_hidden_size
],
dtype
=
input
.
dtype
,
device
=
input
.
device
,
requires_grad
=
False
)
gemm_only_op
=
flux
.
GemmOnly
(
input_dtype
=
input
.
dtype
,
output_dtype
=
input
.
dtype
,
transpose_weight
=
transpose_weight
,
use_fp8_gemm
=
False
,
)
output
=
gemm_only_op
.
forward
(
input
,
weight
.
t
().
contiguous
()
if
transpose_weight
else
weight
,
bias
=
bias
,
output_buf
=
output_buf
,
input_scale
=
None
,
weight_scale
=
None
,
output_scale
=
None
,
fast_accum
=
False
,
)
output
=
output
.
view
(
sequence_len
,
batch_size
,
-
1
)
output
=
_reduce
(
output
)
output
=
_reduce
(
output
)
# torch.cuda.current_stream().synchronize()
# torch.cuda.current_stream().synchronize()
...
@@ -556,6 +538,7 @@ class LinearRS(torch.autograd.Function):
...
@@ -556,6 +538,7 @@ class LinearRS(torch.autograd.Function):
grad_output_buffer
=
ctx
.
grad_output_buffer
grad_output_buffer
=
ctx
.
grad_output_buffer
wgrad_deferral_limit
=
ctx
.
wgrad_deferral_limit
wgrad_deferral_limit
=
ctx
.
wgrad_deferral_limit
transpose_weight
=
ctx
.
transpose_weight
transpose_weight
=
ctx
.
transpose_weight
bw_ag_gemm_op
=
ctx
.
bw_ag_gemm_op
wgrad_compute
=
True
wgrad_compute
=
True
if
grad_output_buffer
is
not
None
:
if
grad_output_buffer
is
not
None
:
...
@@ -587,7 +570,8 @@ class LinearRS(torch.autograd.Function):
...
@@ -587,7 +570,8 @@ class LinearRS(torch.autograd.Function):
sequence_len
,
batch_size
,
output_hidden_size
=
grad_output
.
size
()
sequence_len
,
batch_size
,
output_hidden_size
=
grad_output
.
size
()
input_hidden_size
=
weight
.
size
(
-
1
)
input_hidden_size
=
weight
.
size
(
-
1
)
ag_kernel
=
flux
.
AGKernel
(
if
bw_gemm_rs_op
is
None
:
bw_ag_gemm_op
=
flux
.
AGKernel
(
get_tensor_model_parallel_group
(),
get_tensor_model_parallel_group
(),
1
,
#world_size // torch.cuda.device_count(),
1
,
#world_size // torch.cuda.device_count(),
sequence_len
*
batch_size
*
world_size
,
sequence_len
*
batch_size
*
world_size
,
...
@@ -599,7 +583,7 @@ class LinearRS(torch.autograd.Function):
...
@@ -599,7 +583,7 @@ class LinearRS(torch.autograd.Function):
local_copy
=
False
,
local_copy
=
False
,
ring_mode
=
flux
.
AgRingMode
.
Auto
,
ring_mode
=
flux
.
AgRingMode
.
Auto
,
)
)
grad_input
=
ag_kernel
.
forward
(
grad_input
=
bw_ag_gemm_op
.
forward
(
grad_output
.
view
(
sequence_len
*
batch_size
,
-
1
),
grad_output
.
view
(
sequence_len
*
batch_size
,
-
1
),
weight
if
transpose_weight
else
weight
.
t
().
contiguous
(),
weight
if
transpose_weight
else
weight
.
t
().
contiguous
(),
bias
=
None
,
bias
=
None
,
...
@@ -662,7 +646,7 @@ class LinearRS(torch.autograd.Function):
...
@@ -662,7 +646,7 @@ class LinearRS(torch.autograd.Function):
grad_weight
=
total_grad_output
.
t
().
matmul
(
total_input
)
grad_weight
=
total_grad_output
.
t
().
matmul
(
total_input
)
grad_bias
=
total_grad_output
.
sum
(
dim
=
0
)
if
use_bias
else
None
grad_bias
=
total_grad_output
.
sum
(
dim
=
0
)
if
use_bias
else
None
return
grad_input
,
grad_weight
,
grad_bias
,
None
,
None
,
None
,
None
,
None
,
None
return
grad_input
,
grad_weight
,
grad_bias
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
def
linear_rs
(
def
linear_rs
(
...
@@ -675,6 +659,8 @@ def linear_rs(
...
@@ -675,6 +659,8 @@ def linear_rs(
grad_output_buffer
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
grad_output_buffer
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
wgrad_deferral_limit
:
Optional
[
int
]
=
0
,
wgrad_deferral_limit
:
Optional
[
int
]
=
0
,
transpose_weight
:
Optional
[
bool
]
=
False
,
transpose_weight
:
Optional
[
bool
]
=
False
,
fw_gemm_rs_op
=
None
,
bw_ag_gemm_op
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Linear layer execution with asynchronous communication and
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
gradient accumulation fusion in backprop.
...
@@ -736,6 +722,11 @@ def linear_rs(
...
@@ -736,6 +722,11 @@ def linear_rs(
deferred. Disable by setting this to 0. Defaults to 0.
deferred. Disable by setting this to 0. Defaults to 0.
transpose_weight: transpose weight.
transpose_weight: transpose weight.
fw_gemm_rs_op: flux AGKernel for forward.
bw_ag_gemm_op: flux GemmRS for backward.
"""
"""
args
=
[
args
=
[
...
@@ -748,6 +739,8 @@ def linear_rs(
...
@@ -748,6 +739,8 @@ def linear_rs(
grad_output_buffer
,
grad_output_buffer
,
wgrad_deferral_limit
,
wgrad_deferral_limit
,
transpose_weight
,
transpose_weight
,
fw_gemm_rs_op
,
bw_ag_gemm_op
,
]
]
if
not
linear_rs
.
warned
:
if
not
linear_rs
.
warned
:
...
@@ -976,6 +969,11 @@ def row_parallel_linear_init_wrapper(fn):
...
@@ -976,6 +969,11 @@ def row_parallel_linear_init_wrapper(fn):
elif
hasattr
(
self
.
config
,
"flux_transpose_weight"
):
elif
hasattr
(
self
.
config
,
"flux_transpose_weight"
):
self
.
flux_transpose_weight
=
self
.
config
.
flux_transpose_weight
self
.
flux_transpose_weight
=
self
.
config
.
flux_transpose_weight
if
self
.
sequence_parallel
:
self
.
previous_flux_params
=
(
None
,)
*
5
self
.
fw_gemm_rs_op
=
None
self
.
bw_ag_gemm_op
=
None
return
wrapper
return
wrapper
...
@@ -1012,6 +1010,50 @@ class RowParallelLinearPatch(torch.nn.Module):
...
@@ -1012,6 +1010,50 @@ class RowParallelLinearPatch(torch.nn.Module):
input_parallel
=
scatter_to_tensor_model_parallel_region
(
input_
)
input_parallel
=
scatter_to_tensor_model_parallel_region
(
input_
)
# Matrix multiply.
# Matrix multiply.
if
self
.
use_flux
:
if
self
.
use_flux
:
assert
HAS_FLUX
,
"flux is NOT installed"
sequence_len
,
batch_size
,
input_hidden_size
=
input_parallel
.
size
()
output_hidden_size
=
weight
.
size
(
0
)
world_size
=
get_tensor_model_parallel_world_size
()
if
self
.
sequence_parallel
:
current_flux_params
=
(
sequence_len
,
batch_size
,
input_hidden_size
,
output_hidden_size
,
input_parallel
.
dtype
)
if
(
self
.
fw_gemm_rs_op
is
None
or
current_flux_params
!=
self
.
previous_flux_params
):
self
.
fw_gemm_rs_op
=
flux
.
GemmRS
(
get_tensor_model_parallel_group
(),
1
,
# world_size // torch.cuda.device_count(),
sequence_len
*
batch_size
,
output_hidden_size
,
input_parallel
.
dtype
,
input_parallel
.
dtype
,
transpose_weight
=
self
.
flux_transpose_weight
,
fuse_reduction
=
False
)
self
.
bw_ag_gemm_op
=
flux
.
AGKernel
(
get_tensor_model_parallel_group
(),
1
,
# torch.distributed.get_world_size() // torch.cuda.device_count(),
sequence_len
*
batch_size
,
input_hidden_size
,
output_hidden_size
,
input_parallel
.
dtype
,
output_dtype
=
input_parallel
.
dtype
,
transpose_weight
=
self
.
flux_transpose_weight
,
local_copy
=
False
,
ring_mode
=
flux
.
AgRingMode
.
Auto
,
)
self
.
previous_flux_params
=
current_flux_params
self
.
_forward_impl
=
linear_rs
self
.
_forward_impl
=
linear_rs
elif
not
self
.
weight
.
requires_grad
:
elif
not
self
.
weight
.
requires_grad
:
self
.
_forward_impl
=
linear_with_frozen_weight
self
.
_forward_impl
=
linear_with_frozen_weight
...
@@ -1031,7 +1073,11 @@ class RowParallelLinearPatch(torch.nn.Module):
...
@@ -1031,7 +1073,11 @@ class RowParallelLinearPatch(torch.nn.Module):
}
}
if
self
.
use_flux
:
if
self
.
use_flux
:
forward_params
.
update
({
"transpose_weight"
:
self
.
flux_transpose_weight
})
forward_params
.
update
({
"transpose_weight"
:
self
.
flux_transpose_weight
,
"fw_gemm_rs_op"
:
self
.
fw_gemm_rs_op
,
"bw_ag_gemm_op"
:
self
.
bw_ag_gemm_op
,
})
output_parallel
=
self
.
_forward_impl
(
**
forward_params
)
output_parallel
=
self
.
_forward_impl
(
**
forward_params
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment