Repository: gaoqiong/flash-attention

Commit 226a1b72, authored Dec 23, 2022 by Tri Dao
Parent: dff68c2b

    Implement TensorParallel for FusedDense and FusedDenseGeluDense

Showing 5 changed files with 502 additions and 59 deletions (+502 -59)
csrc/fused_dense_lib/fused_dense.cpp      +13   -0
flash_attn/modules/mlp.py                  +2   -2
flash_attn/ops/fused_dense.py            +233  -57
flash_attn/utils/distributed.py           +74   -0  (new file)
tests/ops/test_fused_dense_parallel.py   +180   -0  (new file)
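For orientation, here is a hedged sketch of how the modules this commit introduces are meant to be constructed. The class names and constructor arguments come from the diff below; the process-group setup, and the assumption that the fused_dense_lib CUDA extension is built and that the script runs under torchrun, are mine.

    # Illustrative only; assumes `torchrun --nproc_per_node=<num_gpus> script.py`
    # and a built fused_dense_lib extension.
    import torch
    from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense

    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(torch.distributed.get_rank())
    group = torch.distributed.new_group()   # tensor-parallel group (here: all ranks)

    # Linear layer whose output features are split across the ranks of `group`.
    fc = ColumnParallelLinear(1024, 4096, group, device='cuda', dtype=torch.float16)

    # Fused MLP whose fc1/fc2 are column-/row-parallel; its input is expected to be
    # sharded along the sequence dimension and its output comes back sharded the same way.
    mlp = ParallelFusedDenseGeluDense(1024, 4096, 1024, process_group=group,
                                      device='cuda', dtype=torch.float16)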
csrc/fused_dense_lib/fused_dense.cpp
@@ -2,6 +2,7 @@
 // We make it work for bfloat16
 #include <torch/extension.h>
 #include <torch/torch.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <vector>
 #include <stdio.h>

@@ -50,6 +51,10 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
  CHECK_SHAPE(input, batch_size, in_features);
  CHECK_SHAPE(d_output, batch_size, out_features);
+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
  // create output/workspace tensor
  auto opts = input.options();
  auto d_weight = at::empty({out_features, in_features}, opts);

@@ -104,6 +109,10 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
    CHECK_SHAPE(bias, out_features);
  }
+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
  // create output/workspace tensor
  auto opts = input.options();
  auto output = at::empty({batch_size, out_features}, opts);

@@ -153,6 +162,10 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
  CHECK_SHAPE(d_output, batch_size, out_features);
  CHECK_SHAPE(gelu_in, batch_size, in_features);
+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::cuda::CUDAGuard device_guard{(char)weight.get_device()};
  // create output/workspace tensor
  auto opts = weight.options();
  auto d_bias = at::empty({in_features}, opts);
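The added CUDAGuard lines make the extension launch its kernels and allocations on the device of the input tensor instead of the default cuda:0. A minimal sketch of the same idea from the Python side, using only plain PyTorch (torch.cuda.device is the rough Python analogue of c10::cuda::CUDAGuard; assumes a machine with at least two GPUs):

    import torch

    # Purely illustrative; needs at least two visible GPUs.
    x = torch.randn(8, 8, device='cuda:1')

    # While the guard is active, "current device" work happens on x's device
    # instead of cuda:0 -- the same effect the CUDAGuard gives the extension.
    with torch.cuda.device(x.device):
        assert torch.cuda.current_device() == 1
        y = torch.empty(8, 8, device='cuda')   # lands on cuda:1, next to x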
flash_attn/modules/mlp.py
@@ -5,9 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F

try:
-    from flash_attn.ops.fused_dense import FusedDenseGeluDense
+    from flash_attn.ops.fused_dense import FusedDenseGeluDense, ParallelFusedDenseGeluDense
except ImportError:
-    FusedDenseGeluDense = None
+    FusedDenseGeluDense, ParallelFusedDenseGeluDense = None, None


class Mlp(nn.Module):
flash_attn/ops/fused_dense.py
-# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
+# Copyright (c) 2022, Tri Dao.
+# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
+# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup

from torch.cuda.amp import custom_bwd, custom_fwd

# import fused_dense_cuda  # from apex
import fused_dense_lib as fused_dense_cuda

from flash_attn.ops.gelu_activation import gelu_bwd
+from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, reduce_scatter
class FusedDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False):
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None):
+        """
+        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
+        we do an all_gather_raw of x before doing the matmul.
+        """
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight = [a.to(dtype=dtype) for a in [x, weight]]
            bias = bias.to(dtype=dtype) if bias is not None else None
        ctx.return_residual = return_residual
+        ctx.process_group = process_group
+        ctx.compute_weight_gradient = weight.requires_grad
        x = x.contiguous()
        weight = weight.contiguous()
-        ctx.save_for_backward(x, weight)
+        if ctx.compute_weight_gradient:
+            ctx.save_for_backward(x, weight)
+        else:
+            ctx.save_for_backward(weight)
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        output = F.linear(x, weight, bias)
+        if process_group is not None:
+            total_x, _ = all_gather_raw(x, process_group)
+        else:
+            total_x = x
+        output = F.linear(total_x, weight, bias)
        return output if not return_residual else (output, x)

    @staticmethod
    ...
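The new docstring describes the sequence-parallel convention: each rank holds a slice of the sequence and the full input is reassembled with an all-gather right before the matmul. A shape-only sketch of that convention, with the gather emulated by torch.cat so it runs on one device without a process group (the sizes are illustrative):

    import torch
    import torch.nn.functional as F

    world_size, seqlen, d, d_out = 4, 512, 1024, 4096

    # Each rank holds seqlen // world_size rows of x (the sequence is sharded).
    x_shards = [torch.randn(seqlen // world_size, d) for _ in range(world_size)]

    # all_gather_raw(x, process_group) would give every rank the full sequence:
    total_x = torch.cat(x_shards, dim=0)        # (seqlen, d)

    # FusedDense keeps the full weight on every rank and applies it to total_x.
    weight = torch.randn(d_out, d)
    output = F.linear(total_x, weight)          # (seqlen, d_out)
    assert output.shape == (seqlen, d_out)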
@@ -39,37 +59,56 @@ class FusedDenseFunc(torch.autograd.Function):
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
-        x, weight = ctx.saved_tensors
-        batch_shape, n = x.shape[:-1], x.shape[-1]
+        process_group = ctx.process_group
+        if ctx.compute_weight_gradient:
+            x, weight = ctx.saved_tensors
+            if process_group is not None:
+                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+            else:
+                total_x = x
+        else:
+            weight, = ctx.saved_tensors
+            total_x = None
+        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
-        if ctx.needs_input_grad[1]:
-            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
-                x.reshape(batch_dim, n), grad_output, ctx.needs_input_grad[2]
-            )
-        else:
-            grad_weight = None
-            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
-                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_output, weight)
-            grad_input = grad_input.reshape_as(x)
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
+                                         grad_output, weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+            if process_group is not None:
+                grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
+                                                                   async_op=True)
        else:
            grad_input = None
-        return grad_input, grad_weight, grad_bias, None
+        if ctx.needs_input_grad[1]:
+            assert ctx.compute_weight_gradient
+            if process_group is not None:
+                handle_x.wait()
+            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
+                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
+            )
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+        if process_group is not None and ctx.needs_input_grad[0]:
+            handle_grad_input.wait()
+        return grad_input, grad_weight, grad_bias, None, None
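The rewritten backward issues the all-gather and reduce-scatter with async_op=True and only waits on the handles right before the gathered or scattered tensor is consumed, so communication overlaps with the surrounding matmuls. A minimal sketch of that overlap pattern using raw torch.distributed calls (assumes an initialized NCCL process group, e.g. under torchrun; the helper name is hypothetical):

    import torch

    def overlapped_wgrad_sketch(x_shard, grad_output, weight, process_group):
        # Start gathering the sequence-parallel input shards without blocking.
        world_size = torch.distributed.get_world_size(process_group)
        total_x = torch.empty(world_size * x_shard.shape[0], *x_shard.shape[1:],
                              dtype=x_shard.dtype, device=x_shard.device)
        handle = torch.distributed.all_gather_into_tensor(total_x, x_shard.contiguous(),
                                                          group=process_group, async_op=True)
        # Work that does not need total_x runs while the all-gather is in flight.
        grad_input = grad_output @ weight        # dgrad overlaps with communication
        handle.wait()                            # block only when total_x is needed
        grad_weight = grad_output.t() @ total_x  # wgrad uses the gathered input
        return grad_input, grad_weight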
def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
-                     return_residual: bool = False):
+                     return_residual: bool = False,
+                     process_group: Optional[ProcessGroup] = None):
    batch_dim = x.shape[:-1].numel()
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda)
            and batch_dim <= 64 * 1024 and dtype_eligible):
-        return FusedDenseFunc.apply(x, weight, bias, return_residual)
+        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
    else:
+        assert process_group is None
        out = F.linear(x, weight, bias)
        return out if not return_residual else (out, x)
@@ -81,17 +120,69 @@ class FusedDense(nn.Linear):
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.return_residual = return_residual

    def forward(self, x, process_group=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul.
        """
        return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual,
                                process_group=process_group)
class ColumnParallelLinear(nn.Linear):

    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
                 bias: bool = True, device=None, dtype=None) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % world_size != 0:
            raise ValueError(f'out_features ({out_features}) must be divisible by '
                             f'world_size ({world_size})')
        super().__init__(in_features, out_features // world_size, bias=bias,
                         device=device, dtype=dtype)
        self.process_group = process_group

    def forward(self, x):
-        return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual)
+        """
+        We're doing Tensor Parallel with sequence parallelism: we do an all_gather of
+        x before doing the matmul.
+        """
+        return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
class RowParallelLinear(nn.Linear):

    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
                 bias: bool = True, device=None, dtype=None) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % world_size != 0:
            raise ValueError(f'in_features ({in_features}) must be divisible by '
                             f'world_size ({world_size})')
        # Only rank 0 will have bias
        super().__init__(in_features // world_size, out_features, bias=bias and rank == 0,
                         device=device, dtype=dtype)
        self.process_group = process_group

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = fused_dense_func(x, self.weight, self.bias)
        return reduce_scatter(out, self.process_group)
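ColumnParallelLinear shards the output features of the first matmul and RowParallelLinear shards the input features of the second, so the GELU in between stays fully local and only one reduction is needed at the end. A single-process numerical sketch of that Megatron-style split, emulating the ranks in a loop with plain PyTorch (the emulation and sizes are illustrative, not the fused path):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    d, h, world_size = 64, 256, 4
    x = torch.randn(32, d)
    w1, b1 = torch.randn(h, d), torch.randn(h)   # fc1: column-parallel (split rows of w1)
    w2, b2 = torch.randn(d, h), torch.randn(d)   # fc2: row-parallel (split columns of w2)

    ref = F.linear(F.gelu(F.linear(x, w1, b1), approximate='tanh'), w2, b2)

    # Each "rank" computes its shard; the partial fc2 outputs are summed (the reduce).
    partials = []
    for r in range(world_size):
        rows = slice(r * h // world_size, (r + 1) * h // world_size)
        hidden_r = F.gelu(F.linear(x, w1[rows], b1[rows]), approximate='tanh')
        partials.append(F.linear(hidden_r, w2[:, rows]))   # bias added once, below
    out = sum(partials) + b2                               # the all-reduce / reduce-scatter analogue

    assert torch.allclose(out, ref, rtol=1e-4, atol=1e-4)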
class FusedDenseGeluDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
-    def forward(ctx, x, weight1, bias1, weight2, bias2, save_gelu_in=True, return_residual=False,
-                checkpoint_lvl=0, heuristic=0):
-        """checkpoint_lvl:
+    def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
+                checkpoint_lvl=0, heuristic=0, process_group=None):
+        """
+        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
+        we do an all_gather of x before doing the matmul.
+        checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
        2: recompute gelu_in and gelu_out in the bwd
        ...
@@ -102,28 +193,34 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
            x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
-        if not save_gelu_in:
+        if not save_pre_act:
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
+        ctx.process_group = process_group
        x = x.contiguous()
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
-        batch_shape, n = x.shape[:-1], x.shape[-1]
+        if process_group is not None:
+            total_x, _ = all_gather_raw(x, process_group)
+        else:
+            total_x = x
+        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        if heuristic == -1:
-            gelu_in = F.linear(x, weight1, bias1)
+            gelu_in = F.linear(total_x, weight1, bias1)
            output1 = F.gelu(gelu_in, approximate='tanh')
-            # gelu_in = F.linear(x.reshape(batch_dim, n), weight1)  # This is before adding bias1
+            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1)  # This is before adding bias1
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(gelu_in, bias1)
        else:
-            output1, *rest = fused_dense_cuda.linear_gelu_forward(
-                x.reshape(batch_dim, n), weight1, bias1, save_gelu_in, heuristic
-            )
-            if save_gelu_in:
+            output1, *rest = fused_dense_cuda.linear_gelu_forward(
+                total_x.reshape(batch_dim, n), weight1, bias1, save_pre_act, heuristic
+            )
+            if save_pre_act:
                gelu_in = rest[0]
        output2 = F.linear(output1, weight2, bias2)
        ctx.checkpoint_lvl = checkpoint_lvl
@@ -145,22 +242,31 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
+        process_group = ctx.process_group
        x, weight1, weight2, *rest = ctx.saved_tensors
-        batch_shape, n = x.shape[:-1], x.shape[-1]
+        if process_group is None:
+            total_x = x
+        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
-        if checkpoint_lvl == 0:
-            gelu_in, output1 = rest
-        elif checkpoint_lvl == 1:
-            gelu_in, = rest
-            output1 = F.gelu(gelu_in, approximate='tanh')
+        if checkpoint_lvl in [0, 1]:
+            if process_group is not None:
+                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+            if checkpoint_lvl == 0:
+                gelu_in, output1 = rest
+            elif checkpoint_lvl == 1:
+                gelu_in, = rest
+                output1 = F.gelu(gelu_in, approximate='tanh')
        elif checkpoint_lvl == 2:
            bias1, = rest
+            if process_group is not None:
+                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
-                gelu_in = F.linear(x, weight1, bias1)
+                gelu_in = F.linear(total_x, weight1, bias1)
                output1 = F.gelu(gelu_in, approximate='tanh')
            else:
                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
-                    x.reshape(batch_dim, n), weight1, bias1, True, ctx.heuristic
+                    total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, True, ctx.heuristic
                )
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
@@ -178,13 +284,6 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
            grad_output1 = F.linear(grad_output, weight2.t())
            with torch.jit.fuser('fuser2'):
                grad_gelu = gelu_bwd(grad_output1, gelu_in)
-            if ctx.needs_input_grad[1]:
-                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
-                    x.reshape(batch_dim, n), grad_gelu, ctx.needs_input_grad[2]
-                )
-            else:
-                grad_weight1 = None
-                grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
        else:
            # The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
            # just compute gelu grad
@@ -193,26 +292,49 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
-            if ctx.needs_input_grad[1]:
-                grad_weight1 = F.linear(grad_gelu.t(), x.reshape(batch_dim, n).t())
-            else:
-                grad_weight1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_gelu, weight1.t())
            else:
-                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_gelu, weight1)
-            grad_input = grad_input.reshape_as(x)
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
+                                         grad_gelu, weight1)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+            if process_group is not None:
+                grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
+                                                                   async_op=True)
        else:
            grad_input = None
-        return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None, None, None
+        if ctx.heuristic == -1:
+            if ctx.needs_input_grad[1]:
+                if process_group is not None:
+                    handle_x.wait()
+                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
+                    total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu, ctx.needs_input_grad[2]
+                )
+            else:
+                grad_weight1 = None
+                grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
+        else:
+            if ctx.needs_input_grad[1]:
+                if process_group is not None:
+                    handle_x.wait()
+                grad_weight1 = F.linear(grad_gelu.t(),
+                                        total_x.reshape(batch_dim, total_x.shape[-1]).t())
+            else:
+                grad_weight1 = None
+        if process_group is not None and ctx.needs_input_grad[0]:
+            handle_grad_input.wait()
+        return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
+                None, None, None, None, None)
def fused_dense_gelu_dense_func(
    x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None,
-    save_gelu_in: bool = True, return_residual: bool = False,
-    checkpoint_lvl: int = 0, heuristic: int = 0
+    save_pre_act: bool = True, return_residual: bool = False,
+    checkpoint_lvl: int = 0, heuristic: int = 0,
+    process_group: Optional[ProcessGroup] = None
):
    batch_dim = x.shape[:-1].numel()
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
    ...
@@ -222,9 +344,10 @@ def fused_dense_gelu_dense_func(
            and dtype_eligible):
        return FusedDenseGeluDenseFunc.apply(
            x, weight1, bias1, weight2, bias2,
-            save_gelu_in, return_residual, checkpoint_lvl, heuristic
+            save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group
        )
    else:
+        assert process_group is None
        gelu_in = F.linear(x, weight1, bias1)
        output1 = F.gelu(gelu_in, approximate='tanh')
        output2 = F.linear(output1, weight2, bias2)
@@ -237,6 +360,10 @@ class FusedDenseGeluDense(nn.Module):
                 bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
                 device=None, dtype=None):
        """
+        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
+        we do an all_gather of x before doing the matmul, gelu, then matmul.
+        Finally we do a reduce_scatter of the output.
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
        ...
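checkpoint_lvl trades backward-pass compute for activation memory: level 0 keeps both gelu_in and gelu_out, level 1 keeps only gelu_in and recomputes gelu_out, level 2 recomputes both from the input. A toy sketch of the level-1 idea with a custom autograd function (illustrative only; the real code does this inside the fused kernels):

    import torch
    import torch.nn.functional as F

    class GeluRecompute(torch.autograd.Function):
        """Save only the GELU input; recompute the output where backward needs it."""

        @staticmethod
        def forward(ctx, gelu_in):
            ctx.save_for_backward(gelu_in)               # keep gelu_in, drop gelu_out
            return F.gelu(gelu_in, approximate='tanh')

        @staticmethod
        def backward(ctx, grad_output):
            gelu_in, = ctx.saved_tensors
            with torch.enable_grad():
                gelu_in_ = gelu_in.detach().requires_grad_()
                out = F.gelu(gelu_in_, approximate='tanh')   # recomputed in the bwd
            return torch.autograd.grad(out, gelu_in_, grad_output)[0]

    x = torch.randn(8, 16, requires_grad=True)
    GeluRecompute.apply(x).sum().backward()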
@@ -261,9 +388,58 @@ class FusedDenseGeluDense(nn.Module):
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x, process_group=None):
        out = fused_dense_gelu_dense_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            save_pre_act=self.training, return_residual=self.return_residual,
            checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic,
            process_group=process_group
        )
        if self.return_residual:
            out, x = out
        if process_group is not None:
            out = reduce_scatter(out, process_group)
        return out if not self.return_residual else (out, x)
class ParallelFusedDenseGeluDense(nn.Module):

    def __init__(self, in_features, hidden_features, out_features=None,
                 process_group: ProcessGroup = None, bias1=True, bias2=True,
                 checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
        """
        process_group is required. We're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
        For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
        For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert process_group is not None
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.process_group = process_group
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
                                        bias=bias1, **factory_kwargs)
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
                                     bias=bias2, **factory_kwargs)

    def forward(self, x):
-        return fused_dense_gelu_dense_func(
+        out = fused_dense_gelu_dense_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
-            save_gelu_in=self.training, return_residual=self.return_residual,
-            checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic
+            save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
+            heuristic=self.heuristic, process_group=self.process_group
        )
        return reduce_scatter(out, self.process_group)
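Putting the pieces together, a hedged sketch of one forward/backward step with the new parallel MLP, following the same apex-based process-group setup the test file below uses (requires torchrun, NCCL, apex.transformer, and the fused_dense_lib extension):

    # torchrun --nproc_per_node=2 this_script.py   (illustrative launch command)
    import torch
    from apex.transformer import parallel_state, tensor_parallel
    from flash_attn.ops.fused_dense import ParallelFusedDenseGeluDense

    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(torch.distributed.get_rank())
    device = f'cuda:{torch.distributed.get_rank()}'
    world_size = torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    group = parallel_state.get_tensor_model_parallel_group()

    mlp = ParallelFusedDenseGeluDense(1024, 4096, 1024, process_group=group,
                                      device=device, dtype=torch.float16)

    x_full = torch.randn(8 * 512, 1024, device=device, dtype=torch.float16)
    # Shard the (batch * seqlen) dimension across ranks: sequence parallelism.
    x = tensor_parallel.scatter_to_sequence_parallel_region(x_full).detach().requires_grad_()

    out = mlp(x)                 # output is sharded the same way as x
    out.sum().backward()         # grads flow through reduce_scatter / all_gather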
flash_attn/utils/distributed.py
new file (0 → 100644)

from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup


# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
                         dtype=input_.dtype, device=input_.device)
    handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(),
                                                      group=process_group, async_op=async_op)
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:],
                         dtype=input_.dtype, device=input_.device)
    handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
                                                     group=process_group, async_op=async_op)
    return output, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce scatter the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply
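AllGatherFunc and ReduceScatterFunc are adjoints of each other: the backward of gathering the sequence shards is a reduce-scatter of the incoming gradient, and vice versa. A single-device sketch of why, emulating the gather with torch.cat and checking the per-shard gradients with autograd (across real ranks the matching slices from all ranks are additionally summed, which is exactly what reduce_scatter does):

    import torch

    world_size, n, d = 4, 128, 64
    shards = [torch.randn(n, d, requires_grad=True) for _ in range(world_size)]

    # Emulate all_gather on one device: concatenate the per-rank shards.
    gathered = torch.cat(shards, dim=0)            # (world_size * n, d)
    loss = (gathered ** 2).sum()
    loss.backward()

    # Each "rank" receives the slice of the gathered-tensor gradient that matches
    # its own shard -- the scatter half of reduce_scatter.
    grad_gathered = 2 * torch.cat([s.detach() for s in shards], dim=0)
    for r, s in enumerate(shards):
        assert torch.allclose(s.grad, grad_gathered[r * n:(r + 1) * n])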
tests/ops/test_fused_dense_parallel.py
new file (0 → 100644)

# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/ops/test_fused_dense_parallel.py
import math

import torch
import torch.nn.functional as F
import pytest

from apex.transformer import parallel_state
from apex.transformer import tensor_parallel

from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [8])
@pytest.mark.parametrize('has_bias', [True, False])
# @pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize('out_features', [1024, 4096])
# @pytest.mark.parametrize('out_features', [1024])
@pytest.mark.parametrize('in_features', [1024, 4096])
# @pytest.mark.parametrize('in_features', [4096])
def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtype):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
    partition_out_features = out_features // world_size
    model = ColumnParallelLinear(in_features, out_features,
                                 parallel_state.get_tensor_model_parallel_group(),
                                 bias=has_bias, device=device, dtype=dtype)
    with torch.no_grad():
        model.weight.copy_(
            model_pt.weight[rank * partition_out_features:(rank + 1) * partition_out_features]
        )
        if has_bias:
            model.bias.copy_(
                model_pt.bias[rank * partition_out_features:(rank + 1) * partition_out_features]
            )

    out = model(x)
    out_pt = model_pt(x_pt)
    assert torch.allclose(
        out, out_pt[:, rank * partition_out_features:(rank + 1) * partition_out_features],
        rtol=rtol, atol=atol
    )

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out_pt) / 32
    out_pt.backward(g)
    out.backward(g[:, rank * partition_out_features:(rank + 1) * partition_out_features])
    parallel_state.destroy_model_parallel()

    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
        rtol=rtol, atol=atol
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.weight.grad,
        model_pt.weight.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
        rtol=rtol, atol=atol * 10
    )
    if has_bias:
        assert torch.allclose(
            model.bias.grad,
            model_pt.bias.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
            rtol=rtol, atol=atol * 5
        )


@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('has_bias2', [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize('out_features', [1024, 4096])
# @pytest.mark.parametrize('out_features', [1024])
@pytest.mark.parametrize('in_features', [1024, 4096])
# @pytest.mark.parametrize('in_features', [1024])
def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size, dtype):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
                                   dtype=dtype)
    partition_out_features = out_features // world_size
    partition_in_features = in_features // world_size
    model = ParallelFusedDenseGeluDense(in_features, out_features, in_features,
                                        process_group=parallel_state.get_tensor_model_parallel_group(),
                                        bias2=has_bias2 and rank == 0, device=device, dtype=dtype)
    with torch.no_grad():
        model.fc1.weight.copy_(
            model_pt_fc1.weight[rank * partition_out_features:(rank + 1) * partition_out_features]
        )
        model.fc1.bias.copy_(
            model_pt_fc1.bias[rank * partition_out_features:(rank + 1) * partition_out_features]
        )
        model.fc2.weight.copy_(
            model_pt_fc2.weight[:, rank * partition_out_features:(rank + 1) * partition_out_features]
        )
        if has_bias2 and rank == 0:
            model.fc2.bias.copy_(model_pt_fc2.bias)

    out = model(x)
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
        rtol=rtol, atol=atol
    )

    out_pt.backward(g)
    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
        rtol=rtol, atol=atol
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.fc1.weight.grad,
        model_pt_fc1.weight.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
        rtol=rtol, atol=atol * 10
    )
    assert torch.allclose(
        model.fc1.bias.grad,
        model_pt_fc1.bias.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
        rtol=rtol, atol=atol * 5
    )
    assert torch.allclose(
        model.fc2.weight.grad,
        model_pt_fc2.weight.grad[:, rank * partition_out_features:(rank + 1) * partition_out_features],
        rtol=rtol, atol=atol * 10
    )
    if has_bias2 and rank == 0:
        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)