evt_fugx1 / dcu_megatron · Commits

Commit c0d23a67, authored Apr 15, 2025 by dongcl
Parent: 2d1ebf8f

    move flux kernels outside
Showing 2 changed files with 141 additions and 77 deletions (+141 / -77):

    dcu_megatron/adaptor/megatron_adaptor.py       +8   -3
    dcu_megatron/core/tensor_parallel/layers.py    +133 -74
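In outline, the change replaces "construct the flux communication+GEMM kernels inside AGLinear.forward/backward on every call" with "construct them on the owning module, pass them into the autograd function, and rebuild only when the problem signature changes". A minimal schematic of the two call patterns, with make_kernel as a hypothetical stand-in for the flux constructors:

    # Schematic of the change: per-call construction vs. cached construction.
    # `make_kernel` is a hypothetical stand-in for flux.AGKernel / flux.GemmRS.
    def make_kernel(shape, dtype):
        return {"shape": shape, "dtype": dtype}   # placeholder kernel object


    def forward_per_call(x_shape, dtype):
        kernel = make_kernel(x_shape, dtype)      # old: rebuilt on every call
        return kernel


    def forward_cached(x_shape, dtype, kernel=None):
        if kernel is None:                        # new: built only if not supplied
            kernel = make_kernel(x_shape, dtype)
        return kernel


    cached = make_kernel((64, 16), "bf16")
    assert forward_cached((64, 16), "bf16", kernel=cached) is cached   # reused
    assert forward_per_call((64, 16), "bf16") is not cached            # rebuilt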
dcu_megatron/adaptor/megatron_adaptor.py
...
@@ -168,7 +168,12 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
         from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
-        from ..core.tensor_parallel import ColumnParallelLinearPatch, RowParallelLinearPatch, parallel_linear_init_wrapper
+        from ..core.tensor_parallel import (
+            ColumnParallelLinearPatch,
+            RowParallelLinearPatch,
+            column_parallel_linear_init_wrapper,
+            row_parallel_linear_init_wrapper
+        )

         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
...
@@ -198,12 +203,12 @@ class CoreAdaptation(MegatronAdaptationABC):
         if HAS_FLUX:
             # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
-            #                             parallel_linear_init_wrapper,
+            #                             column_parallel_linear_init_wrapper,
             #                             apply_wrapper=True)
             # MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
             #                             ColumnParallelLinearPatch.forward)
             MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
-                                        parallel_linear_init_wrapper,
+                                        row_parallel_linear_init_wrapper,
                                         apply_wrapper=True)
             MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
                                         RowParallelLinearPatch.forward)
...
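MegatronAdaptation.register itself is not part of this diff; the calls above only make sense if it resolves a dotted path and swaps the named attribute, treating the replacement as a wrapper around the original when apply_wrapper=True. A hedged sketch of such a helper (register_patch is a hypothetical name, not the project's API, and the real adaptor may defer or batch the patching):

    # Hedged sketch of a dotted-path patching helper in the spirit of the
    # MegatronAdaptation.register calls above.
    import importlib


    def register_patch(dotted_path, replacement, apply_wrapper=False):
        # "pkg.mod.Class.attr" -> module path, class name, attribute name
        module_path, cls_name, attr_name = dotted_path.rsplit(".", 2)
        module = importlib.import_module(module_path)
        target = getattr(module, cls_name)
        if apply_wrapper:
            # e.g. row_parallel_linear_init_wrapper(original __init__)
            original = getattr(target, attr_name)
            setattr(target, attr_name, replacement(original))
        else:
            setattr(target, attr_name, replacement)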
dcu_megatron/core/tensor_parallel/layers.py
...
@@ -5,9 +5,9 @@ from typing import Callable, List, Optional
 try:
     import flux
     HAS_FLUX = True
 except ImportError:
     from megatron.training import print_rank_0
     print_rank_0(f"flux is NOT installed")
     HAS_FLUX = False

 import torch
 import torch.nn.functional as F
...
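The try/except above is the usual optional-dependency guard: attempt the import once at module load, record the outcome in a flag (HAS_FLUX), and gate every flux code path on that flag, falling back to plain PyTorch otherwise. A generic sketch of the pattern, with a placeholder package name:

    # Generic form of the optional-dependency guard; the package name is a
    # placeholder, only the flag-and-fallback structure matters.
    import torch

    try:
        import some_optional_pkg  # placeholder for an optional extension such as flux
        HAS_OPTIONAL = True
    except ImportError:
        HAS_OPTIONAL = False


    def matmul_maybe_fused(x, weight):
        if HAS_OPTIONAL:
            pass  # the fused kernel would be dispatched here
        return torch.matmul(x, weight.t())  # portable fallback


    print(matmul_maybe_fused(torch.randn(2, 4), torch.randn(3, 4)).shape)  # (2, 3)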
@@ -171,6 +171,8 @@ class AGLinear(torch.autograd.Function):
         grad_output_buffer,
         wgrad_deferral_limit,
         transpose_weight=False,
+        fw_ag_gemm_op=None,
+        bw_gemm_rs_op=None,
     ):
         """Forward."""
         ctx.save_for_backward(input, weight)
...
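Because AGLinear is a torch.autograd.Function, every argument added to forward() needs a matching gradient slot in backward()'s return tuple (a later hunk grows that tuple by two Nones for exactly this reason). A minimal, self-contained sketch of passing a non-differentiable helper object through an autograd Function; the class and names are illustrative, not the project's:

    # Threading a non-differentiable helper object (here a cached "op") through
    # torch.autograd.Function: each extra forward() argument gets a None entry
    # in backward()'s return tuple.
    import torch


    class HelperAwareLinear(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp, weight, helper_op=None):
            ctx.save_for_backward(inp, weight)
            ctx.helper_op = helper_op            # stored like ctx.bw_gemm_rs_op
            return inp.matmul(weight.t())

        @staticmethod
        def backward(ctx, grad_output):
            inp, weight = ctx.saved_tensors
            grad_input = grad_output.matmul(weight)
            grad_weight = grad_output.t().matmul(inp)
            # One gradient per forward argument; helper_op is not differentiable.
            return grad_input, grad_weight, None


    x = torch.randn(5, 8, requires_grad=True)
    w = torch.randn(3, 8, requires_grad=True)
    HelperAwareLinear.apply(x, w, {"cached": True}).sum().backward()
    print(x.grad.shape, w.grad.shape)  # torch.Size([5, 8]) torch.Size([3, 8])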
@@ -181,61 +183,41 @@ class AGLinear(torch.autograd.Function):
         ctx.wgrad_deferral_limit = wgrad_deferral_limit
         ctx.grad_output_buffer = grad_output_buffer
         ctx.transpose_weight = transpose_weight
+        ctx.bw_gemm_rs_op = bw_gemm_rs_op

-        sequence_len, batch_size, input_hidden_size = input.size()
-        # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
-        input = input.view(sequence_len * batch_size, input_hidden_size)
-        output_hidden_size = weight.size(0)
-        if transpose_weight:
-            weight = weight.t().contiguous()
         if sequence_parallel:
-            sequence_len = sequence_len * get_tensor_model_parallel_world_size()
-            ag_gemm_kernel = flux.AGKernel(
-                get_tensor_model_parallel_group(),
-                1,  # torch.distributed.get_world_size() // torch.cuda.device_count(),
-                sequence_len * batch_size,
-                output_hidden_size,
-                input_hidden_size,
-                input.dtype,
-                output_dtype=input.dtype,
-                transpose_weight=transpose_weight,
-                local_copy=False,
-                ring_mode=flux.AgRingMode.Auto,
-            )
-            output = ag_gemm_kernel.forward(
-                input,
-                weight,
+            sequence_len, batch_size, input_hidden_size = input.size()
+            output_hidden_size = weight.size(0)
+            world_size = get_tensor_model_parallel_world_size()
+            if fw_ag_gemm_op is None:
+                fw_ag_gemm_op = flux.AGKernel(
+                    get_tensor_model_parallel_group(),
+                    1,  # torch.distributed.get_world_size() // torch.cuda.device_count(),
+                    sequence_len * batch_size * world_size,
+                    output_hidden_size,
+                    input_hidden_size,
+                    input.dtype,
+                    output_dtype=input.dtype,
+                    transpose_weight=transpose_weight,
+                    local_copy=False,
+                    ring_mode=flux.AgRingMode.Auto,
+                )
+            output = fw_ag_gemm_op.forward(
+                input.view(sequence_len * batch_size, -1),
+                weight.t().contiguous() if transpose_weight else weight,
                 bias=bias,
                 input_scale=None,
                 weight_scale=None,
                 output_scale=None,
                 fast_accum=False
             )
+            output = output.view(sequence_len * world_size, batch_size, -1)
         else:
-            output_buf = torch.empty([sequence_len * batch_size, output_hidden_size], dtype=input.dtype, device=input.device)
-            gemm_only_op = flux.GemmOnly(
-                input_dtype=input.dtype,
-                output_dtype=input.dtype,
-                transpose_weight=transpose_weight,
-                use_fp8_gemm=False,
-            )
-            output = gemm_only_op.forward(
-                input,
-                weight,
-                bias=bias,
-                output_buf=output_buf,
-                input_scale=None,
-                weight_scale=None,
-                output_scale=None,
-                fast_accum=False,
-            )
+            output = torch.matmul(input, weight.t())

         torch.cuda.current_stream().synchronize()
-        output = output.view(sequence_len, batch_size, -1)
         return output
...
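The shape bookkeeping in the new sequence-parallel branch: the local input is [sequence, batch, hidden] with sequence being the per-rank shard, the fused all-gather + GEMM produces rows for sequence * world_size, and the result is viewed back to [sequence * world_size, batch, output_hidden]. A single-process illustration of the equivalent shapes in plain PyTorch, with the all-gather simulated by concatenating shards:

    # Equivalent shapes of the fused all-gather + GEMM, simulated without
    # torch.distributed: gather shards along the sequence dimension, then matmul.
    import torch

    world_size, sequence_len, batch_size, hidden, out_hidden = 4, 8, 2, 16, 32
    shards = [torch.randn(sequence_len, batch_size, hidden) for _ in range(world_size)]
    weight = torch.randn(out_hidden, hidden)

    gathered = torch.cat(shards, dim=0)                       # [seq * world, batch, hidden]
    output = gathered.view(-1, hidden).matmul(weight.t())     # [seq * world * batch, out_hidden]
    output = output.view(sequence_len * world_size, batch_size, -1)
    print(output.shape)                                       # torch.Size([32, 2, 32])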
@@ -248,6 +230,7 @@ class AGLinear(torch.autograd.Function):
         grad_output_buffer = ctx.grad_output_buffer
         wgrad_deferral_limit = ctx.wgrad_deferral_limit
         transpose_weight = ctx.transpose_weight
+        bw_gemm_rs_op = ctx.bw_gemm_rs_op
         wgrad_compute = True

         if grad_output_buffer is not None:
...
@@ -275,25 +258,23 @@ class AGLinear(torch.autograd.Function):
         total_input = input

         if ctx.sequence_parallel:
-            sequence_len, batch_size, output_hidden_size = grad_output.size()
-            input_hidden_size = weight.size(-1)
-            # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
-            grad_output = grad_output.view(sequence_len * batch_size, output_hidden_size)
-            gemm_rs_op = flux.GemmRS(
-                get_tensor_model_parallel_group(),
-                1,  # world_size // torch.cuda.device_count(),
-                sequence_len * batch_size,
-                input_hidden_size,
-                input.dtype,
-                input.dtype,
-                transpose_weight=transpose_weight,
-                fuse_reduction=False
-            )
-            grad_input = gemm_rs_op.forward(
-                grad_output,
+            sequence_len, batch_size, _ = grad_output.size()
+            if bw_gemm_rs_op is None:
+                input_hidden_size = weight.size(-1)
+                bw_gemm_rs_op = flux.GemmRS(
+                    get_tensor_model_parallel_group(),
+                    1,  # world_size // torch.cuda.device_count(),
+                    sequence_len * batch_size,
+                    input_hidden_size,
+                    input.dtype,
+                    input.dtype,
+                    transpose_weight=transpose_weight,
+                    fuse_reduction=False
+                )
+            grad_input = bw_gemm_rs_op.forward(
+                grad_output.view(sequence_len * batch_size, -1),
+                weight if transpose_weight else weight.t().contiguous(),
                 bias=None,
                 input_scale=None,
...
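The backward path no longer reassigns grad_output to a flattened tensor; it passes a 2D view into the fused call, so the original [sequence, batch, hidden] tensor remains available for the later weight-gradient computation. A small reminder of why that works:

    # `view` returns a new alias over the same storage, so the 3D grad_output is
    # untouched while the fused op sees a 2D [sequence * batch, hidden] view.
    import torch

    grad_output = torch.randn(4, 2, 8)                  # [sequence, batch, hidden]
    two_d = grad_output.view(4 * 2, -1)                 # alias used only for the fused call
    print(two_d.shape, grad_output.shape)               # torch.Size([8, 8]) torch.Size([4, 2, 8])
    print(two_d.data_ptr() == grad_output.data_ptr())   # True: same underlying storage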
@@ -310,9 +291,6 @@ class AGLinear(torch.autograd.Function):
         if ctx.sequence_parallel and wgrad_compute:
             handle.wait()

-        if ctx.sequence_parallel:
-            grad_output = grad_output.view(sequence_len, batch_size, output_hidden_size)
-
         if wgrad_compute:
             grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
                 grad_output, total_input
...
@@ -323,8 +301,6 @@ class AGLinear(torch.autograd.Function):
             handle = torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=True)
             # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
             # all-reduce is scheduled before the weight gradient computation

         if ctx.gradient_accumulation_fusion:
             if wgrad_compute:
...
@@ -365,10 +341,10 @@ class AGLinear(torch.autograd.Function):
             grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        if ctx.allreduce_dgrad:
+        if not ctx.sequence_parallel and ctx.allreduce_dgrad:
             handle.wait()

-        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None


 def ag_linear(
...
@@ -381,6 +357,8 @@ def ag_linear(
     grad_output_buffer: Optional[List[torch.Tensor]] = None,
     wgrad_deferral_limit: Optional[int] = 0,
     transpose_weight: Optional[bool] = False,
+    fw_ag_gemm_op=None,
+    bw_gemm_rs_op=None,
 ) -> torch.Tensor:
     """Linear layer execution with asynchronous communication and
     gradient accumulation fusion in backprop.
...
@@ -442,6 +420,11 @@ def ag_linear(
             deferred. Disable by setting this to 0. Defaults to 0.

         transpose_weight: transpose weight.
+        fw_ag_gemm_op: flux AGKernel for forward.
+        bw_gemm_rs_op: flux GemmRS for backward.
     """
     args = [
...
@@ -454,6 +437,8 @@ def ag_linear(
         grad_output_buffer,
         wgrad_deferral_limit,
         transpose_weight,
+        fw_ag_gemm_op,
+        bw_gemm_rs_op,
     ]

     if not ag_linear.warned:
...
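ag_linear packs its arguments into a positional list, so the two new entries must appear in the same position as the new parameters of AGLinear.forward; their gradient slots are the two extra Nones added to the return tuple above. A compact sketch of that wrapper pattern, with illustrative names rather than the project's:

    # Positional args are packed in the exact order of the autograd Function's
    # forward() signature and handed to .apply().
    import torch


    class _FnSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp, weight, fw_op=None, bw_op=None):
            ctx.save_for_backward(inp, weight)
            return inp.matmul(weight.t())

        @staticmethod
        def backward(ctx, grad_output):
            inp, weight = ctx.saved_tensors
            # grads for inp and weight, None for the two helper ops
            return grad_output.matmul(weight), grad_output.t().matmul(inp), None, None


    def ag_linear_sketch(inp, weight, fw_op=None, bw_op=None):
        args = [inp, weight, fw_op, bw_op]   # order must match forward() exactly
        return _FnSketch.apply(*args)


    print(ag_linear_sketch(torch.randn(2, 4), torch.randn(3, 4)).shape)  # (2, 3)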
@@ -533,8 +518,6 @@ class LinearRS(torch.autograd.Function):
                 fast_accum=False,
             )
             output = output.view(sequence_len // world_size, batch_size, -1)
-            # output = torch.matmul(input, weight.t())
-            # return _reduce_scatter_along_first_dim(output)
         else:
             output_buf = torch.empty([sequence_len * batch_size, output_hidden_size],
...
@@ -791,7 +774,7 @@ def linear_rs(
 linear_rs.warned = False


-def parallel_linear_init_wrapper(fn):
+def column_parallel_linear_init_wrapper(fn):
     @wraps(fn)
     def wrapper(self, *args, **kwargs):
         fn(self, *args, **kwargs)
...
@@ -809,6 +792,13 @@ def parallel_linear_init_wrapper(fn):
         elif hasattr(self.config, "flux_transpose_weight"):
             self.flux_transpose_weight = self.config.flux_transpose_weight

+        if self.sequence_parallel:
+            self.previous_flux_params = (None,) * 5
+            self.fw_ag_gemm_op = None
+            self.bw_gemm_rs_op = None
+        else:

     return wrapper
...
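The wrapper decorates the original __init__: it lets it run first, then attaches the flux flags and empty kernel-cache slots, with functools.wraps preserving the wrapped function's metadata. A sketch of the same decorator shape on a generic class (names and the manual application are illustrative; in the repository the wrapper is installed via MegatronAdaptation.register with apply_wrapper=True):

    # Init-wrapper shape: run the original __init__, then attach extra attributes.
    from functools import wraps


    def cache_slots_init_wrapper(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            fn(self, *args, **kwargs)                 # original __init__ runs first
            self.use_flux = kwargs.get("use_flux", False)
            self.previous_flux_params = (None,) * 5   # empty cache key
            self.fw_ag_gemm_op = None                 # empty kernel-cache slots
            self.bw_gemm_rs_op = None
        return wrapper


    class Dummy:
        def __init__(self, dim, **kwargs):
            self.dim = dim


    Dummy.__init__ = cache_slots_init_wrapper(Dummy.__init__)
    layer = Dummy(8, use_flux=True)
    print(layer.use_flux, layer.fw_ag_gemm_op)        # True None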
@@ -884,6 +874,50 @@ class ColumnParallelLinearPatch(torch.nn.Module):
         # Matrix multiply.
         if self.use_flux:
             assert HAS_FLUX, "flux is NOT installed"
+            sequence_len, batch_size, input_hidden_size = input_parallel.size()
+            output_hidden_size = weight.size(0)
+            world_size = get_tensor_model_parallel_world_size()
+            if self.sequence_parallel:
+                current_fw_params = (sequence_len, batch_size, input_hidden_size, output_hidden_size, input_parallel.dtype)
+                if (self.fw_ag_gemm_op is None or current_flux_params != self.previous_flux_params):
+                    self.fw_ag_gemm_op = flux.AGKernel(
+                        get_tensor_model_parallel_group(),
+                        1,  # torch.distributed.get_world_size() // torch.cuda.device_count(),
+                        sequence_len * batch_size * world_size,
+                        output_hidden_size,
+                        input_hidden_size,
+                        input_parallel.dtype,
+                        output_dtype=input_parallel.dtype,
+                        transpose_weight=self.flux_transpose_weight,
+                        local_copy=False,
+                        ring_mode=flux.AgRingMode.Auto,
+                    )
+                    self.bw_gemm_rs_op = flux.GemmRS(
+                        get_tensor_model_parallel_group(),
+                        1,  # world_size // torch.cuda.device_count(),
+                        sequence_len * batch_size,
+                        input_hidden_size,
+                        input_parallel.dtype,
+                        input_parallel.dtype,
+                        transpose_weight=self.flux_transpose_weight,
+                        fuse_reduction=False
+                    )
+                    self.previous_flux_params = current_fw_params
             self._forward_impl = ag_linear
         elif not weight.requires_grad:
             self._forward_impl = linear_with_frozen_weight
...
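The cache key is a tuple of everything the flux kernels are specialized on: sequence length, batch size, input and output hidden sizes, and dtype. When any of them changes, both kernels are rebuilt together and the key is updated. A flux-independent sketch of that invalidation check (build_ops is a placeholder for the AGKernel/GemmRS pair):

    # Tuple-keyed cache invalidation: rebuild cached objects only when the
    # (shape, dtype) signature they were specialized for changes.
    import torch


    def build_ops(params):
        return {"specialized_for": params}       # placeholder for two flux kernels


    class FluxLikeCache:
        def __init__(self):
            self.previous_params = (None,) * 5
            self.ops = None

        def get_ops(self, x, weight):
            seq, batch, hidden = x.size()
            current_params = (seq, batch, hidden, weight.size(0), x.dtype)
            if self.ops is None or current_params != self.previous_params:
                self.ops = build_ops(current_params)
                self.previous_params = current_params
            return self.ops


    cache = FluxLikeCache()
    a = cache.get_ops(torch.randn(4, 2, 8), torch.randn(16, 8))
    b = cache.get_ops(torch.randn(4, 2, 8), torch.randn(16, 8))   # same signature
    c = cache.get_ops(torch.randn(6, 2, 8), torch.randn(16, 8))   # new sequence length
    print(a is b, b is c)                                         # True False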
@@ -903,7 +937,11 @@ class ColumnParallelLinearPatch(torch.nn.Module):
             "wgrad_deferral_limit": self.config.wgrad_deferral_limit if self.config.defer_embedding_wgrad_compute else None,
         }
         if self.use_flux:
-            forward_params.update({"transpose_weight": self.flux_transpose_weight})
+            forward_params.update({
+                "transpose_weight": self.flux_transpose_weight,
+                "fw_ag_gemm_op": self.fw_ag_gemm_op,
+                "bw_gemm_rs_op": self.bw_gemm_rs_op,
+            })

         output_parallel = self._forward_impl(**forward_params)
...
@@ -922,6 +960,27 @@ class ColumnParallelLinearPatch(torch.nn.Module):
         return output, output_bias


+def row_parallel_linear_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+
+        # flux params
+        self.use_flux = False
+        if "use_flux" in kwargs:
+            self.use_flux = kwargs["use_flux"]
+        elif hasattr(self.config, "use_flux"):
+            self.use_flux = self.config.use_flux
+
+        self.flux_transpose_weight = False
+        if "flux_transpose_weight" in kwargs:
+            self.flux_transpose_weight = kwargs["flux_transpose_weight"]
+        elif hasattr(self.config, "flux_transpose_weight"):
+            self.flux_transpose_weight = self.config.flux_transpose_weight
+
+    return wrapper
+
+
 class RowParallelLinearPatch(torch.nn.Module):
     """Linear layer with row parallelism.
...