gaoqiong / flash-attention · Commits

Commit 226a1b72, authored Dec 23, 2022 by Tri Dao

Implement TensorParallel for FusedDense and FusedDenseGeluDense

parent dff68c2b
Showing 5 changed files with 502 additions and 59 deletions (+502 −59)
csrc/fused_dense_lib/fused_dense.cpp      +13   −0
flash_attn/modules/mlp.py                  +2   −2
flash_attn/ops/fused_dense.py            +233  −57
flash_attn/utils/distributed.py           +74   −0
tests/ops/test_fused_dense_parallel.py   +180   −0
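Based on how the new tests construct the tensor-parallel modules (the fused_dense.py diff itself is collapsed below), usage looks roughly as follows. This is a hedged sketch: the call signatures are read off tests/ops/test_fused_dense_parallel.py, not off documented API.

# Sketch of constructing the new modules, mirroring the tests in this commit.
import torch
from apex.transformer import parallel_state
from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense

torch.distributed.init_process_group(backend='nccl', init_method='env://')
world_size = torch.distributed.get_world_size()
device = f'cuda:{torch.distributed.get_rank()}'
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
tp_group = parallel_state.get_tensor_model_parallel_group()
rank = parallel_state.get_tensor_model_parallel_rank()

# Column-parallel linear: the 4096 output features are split across the TP ranks.
linear = ColumnParallelLinear(1024, 4096, tp_group, bias=True,
                              device=device, dtype=torch.float16)

# Fused dense -> GELU -> dense; the tests construct the second bias on rank 0 only.
mlp = ParallelFusedDenseGeluDense(1024, 4096, 1024, process_group=tp_group,
                                  bias2=(rank == 0), device=device, dtype=torch.float16)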
csrc/fused_dense_lib/fused_dense.cpp

@@ -2,6 +2,7 @@
 // We make it work for bfloat16
 #include <torch/extension.h>
 #include <torch/torch.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <vector>

 #include <stdio.h>

@@ -50,6 +51,10 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
   CHECK_SHAPE(input, batch_size, in_features);
   CHECK_SHAPE(d_output, batch_size, out_features);

+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+
   // create output/workspace tensor
   auto opts = input.options();
   auto d_weight = at::empty({out_features, in_features}, opts);

@@ -104,6 +109,10 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
     CHECK_SHAPE(bias, out_features);
   }

+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+
   // create output/workspace tensor
   auto opts = input.options();
   auto output = at::empty({batch_size, out_features}, opts);

@@ -153,6 +162,10 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
   CHECK_SHAPE(d_output, batch_size, out_features);
   CHECK_SHAPE(gelu_in, batch_size, in_features);

+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::cuda::CUDAGuard device_guard{(char)weight.get_device()};
+
   // create output/workspace tensor
   auto opts = weight.options();
   auto d_bias = at::empty({in_features}, opts);
flash_attn/modules/mlp.py

@@ -5,9 +5,9 @@ import torch.nn as nn
 import torch.nn.functional as F

 try:
-    from flash_attn.ops.fused_dense import FusedDenseGeluDense
+    from flash_attn.ops.fused_dense import FusedDenseGeluDense, ParallelFusedDenseGeluDense
 except ImportError:
-    FusedDenseGeluDense = None
+    FusedDenseGeluDense, ParallelFusedDenseGeluDense = None, None


 class Mlp(nn.Module):
flash_attn/ops/fused_dense.py

This diff is collapsed (+233 −57; not shown here).
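Since the fused_dense.py diff is collapsed, the following is only a rough, hypothetical sketch of how a tensor-parallel dense → GELU → dense block can be composed from the helpers added in flash_attn/utils/distributed.py (shown next): gather the sequence-parallel input, run a column-parallel fc1 and GELU, then a row-parallel fc2 whose partial sums are combined with a reduce-scatter. It is not the actual ParallelFusedDenseGeluDense implementation, which fuses these steps.

# Hypothetical sketch only; not the ParallelFusedDenseGeluDense code from this commit.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from flash_attn.utils.distributed import all_gather, reduce_scatter  # added by this commit


class NaiveParallelMlp(nn.Module):  # illustrative name, not part of flash-attn
    def __init__(self, in_features, hidden_features, out_features,
                 process_group: ProcessGroup, device=None, dtype=None):
        super().__init__()
        world_size = torch.distributed.get_world_size(process_group)
        assert hidden_features % world_size == 0
        self.process_group = process_group
        # Column-parallel fc1: each rank holds hidden_features // world_size output rows.
        self.fc1 = nn.Linear(in_features, hidden_features // world_size,
                             device=device, dtype=dtype)
        # Row-parallel fc2: each rank holds hidden_features // world_size input columns.
        # No bias here; adding it on every rank would count it world_size times after the
        # reduction (the tests below keep the second bias on rank 0 only).
        self.fc2 = nn.Linear(hidden_features // world_size, out_features, bias=False,
                             device=device, dtype=dtype)

    def forward(self, x):
        # x arrives sharded along the batch/sequence dimension (sequence parallel);
        # gather the full batch before the column-parallel matmul.
        x = all_gather(x, self.process_group)
        y = F.gelu(self.fc1(x), approximate='tanh')
        # Each rank's fc2 output is a partial sum; reduce across ranks and re-shard
        # along the batch dimension in a single reduce-scatter.
        return reduce_scatter(self.fc2(y), self.process_group)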
flash_attn/utils/distributed.py  (new file, 0 → 100644)
from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
                         dtype=input_.dtype, device=input_.device)
    handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(),
                                                      group=process_group, async_op=async_op)
    return output, handle
# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:],
                         dtype=input_.dtype, device=input_.device)
    handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
                                                     group=process_group, async_op=async_op)
    return output, handle
class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply
class ReduceScatterFunc(torch.autograd.Function):
    """Reduce scatter the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply
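The raw helpers return the communication handle precisely so callers can overlap the collective with independent work. A minimal sketch of that pattern (hypothetical caller code, assuming torch.distributed and a process group are already initialized):

# Hypothetical caller illustrating async_op=True with the raw helper above.
import torch
from flash_attn.utils.distributed import all_gather_raw


def gather_then_matmul(x, weight, process_group):
    # Kick off the all-gather without blocking.
    x_full, handle = all_gather_raw(x, process_group, async_op=True)
    # Do work that does not depend on x_full while NCCL runs in the background.
    weight = weight.contiguous()
    # Block only when the gathered activations are actually needed.
    handle.wait()
    return x_full @ weight.t()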
tests/ops/test_fused_dense_parallel.py  (new file, 0 → 100644)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/ops/test_fused_dense_parallel.py
import math

import torch
import torch.nn.functional as F
import pytest

from apex.transformer import parallel_state
from apex.transformer import tensor_parallel

from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [8])
@pytest.mark.parametrize('has_bias', [True, False])
# @pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize('out_features', [1024, 4096])
# @pytest.mark.parametrize('out_features', [1024])
@pytest.mark.parametrize('in_features', [1024, 4096])
# @pytest.mark.parametrize('in_features', [4096])
def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtype):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()

    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
    partition_out_features = out_features // world_size
    model = ColumnParallelLinear(in_features, out_features,
                                 parallel_state.get_tensor_model_parallel_group(),
                                 bias=has_bias, device=device, dtype=dtype)
    with torch.no_grad():
        model.weight.copy_(
            model_pt.weight[rank * partition_out_features:(rank + 1) * partition_out_features]
        )
        if has_bias:
            model.bias.copy_(
                model_pt.bias[rank * partition_out_features:(rank + 1) * partition_out_features]
            )

    out = model(x)
    out_pt = model_pt(x_pt)
    assert torch.allclose(
        out, out_pt[:, rank * partition_out_features:(rank + 1) * partition_out_features],
        rtol=rtol, atol=atol
    )

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out_pt) / 32
    out_pt.backward(g)
    out.backward(g[:, rank * partition_out_features:(rank + 1) * partition_out_features])
    parallel_state.destroy_model_parallel()

    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
        rtol=rtol, atol=atol
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.weight.grad,
        model_pt.weight.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
        rtol=rtol, atol=atol * 10
    )
    if has_bias:
        assert torch.allclose(
            model.bias.grad,
            model_pt.bias.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
            rtol=rtol, atol=atol * 5
        )
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('has_bias2', [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize('out_features', [1024, 4096])
# @pytest.mark.parametrize('out_features', [1024])
@pytest.mark.parametrize('in_features', [1024, 4096])
# @pytest.mark.parametrize('in_features', [1024])
def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size, dtype):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()

    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
                                   dtype=dtype)
    partition_out_features = out_features // world_size
    partition_in_features = in_features // world_size
    model = ParallelFusedDenseGeluDense(in_features, out_features, in_features,
                                        process_group=parallel_state.get_tensor_model_parallel_group(),
                                        bias2=has_bias2 and rank == 0, device=device, dtype=dtype)
    with torch.no_grad():
        model.fc1.weight.copy_(
            model_pt_fc1.weight[rank * partition_out_features:(rank + 1) * partition_out_features]
        )
        model.fc1.bias.copy_(
            model_pt_fc1.bias[rank * partition_out_features:(rank + 1) * partition_out_features]
        )
        model.fc2.weight.copy_(
            model_pt_fc2.weight[:, rank * partition_out_features:(rank + 1) * partition_out_features]
        )
        if has_bias2 and rank == 0:
            model.fc2.bias.copy_(model_pt_fc2.bias)

    out = model(x)
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
        rtol=rtol, atol=atol
    )

    out_pt.backward(g)
    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
        rtol=rtol, atol=atol
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.fc1.weight.grad,
        model_pt_fc1.weight.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
        rtol=rtol, atol=atol * 10
    )
    assert torch.allclose(
        model.fc1.bias.grad,
        model_pt_fc1.bias.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
        rtol=rtol, atol=atol * 5
    )
    assert torch.allclose(
        model.fc2.weight.grad,
        model_pt_fc2.weight.grad[:, rank * partition_out_features:(rank + 1) * partition_out_features],
        rtol=rtol, atol=atol * 10
    )
    if has_bias2 and rank == 0:
        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
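The weight copies in these tests spell out the sharding convention: a column-parallel layer takes a row slice of the full weight (and of the bias), while a row-parallel layer takes a column slice and keeps its bias on rank 0 only. A small helper capturing that convention (hypothetical, mirroring the slicing above):

# Hypothetical helpers mirroring the weight slicing used in the tests above.
import torch


def shard_column_parallel(weight: torch.Tensor, bias, rank: int, world_size: int):
    """Slice of (weight, bias) held by `rank` for a column-parallel linear (output split)."""
    assert weight.shape[0] % world_size == 0
    part = weight.shape[0] // world_size
    w = weight[rank * part:(rank + 1) * part]
    b = bias[rank * part:(rank + 1) * part] if bias is not None else None
    return w, b


def shard_row_parallel(weight: torch.Tensor, bias, rank: int, world_size: int):
    """Slice of (weight, bias) held by `rank` for a row-parallel linear (input split)."""
    assert weight.shape[1] % world_size == 0
    part = weight.shape[1] // world_size
    w = weight[:, rank * part:(rank + 1) * part]
    b = bias if (bias is not None and rank == 0) else None  # bias kept on rank 0 only
    return w, b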