gaoqiong / flash-attention · Commits

Commit cb0daccc, authored Aug 16, 2023 by Tri Dao

[FusedDense] Allow Row/ColumnParallelLinear to have uneven split

Parent: bcfa7c97

Showing 1 changed file with 19 additions and 9 deletions:

flash_attn/ops/fused_dense.py (+19, -9)
@@ -170,16 +170,21 @@ class ColumnParallelLinear(nn.Linear):
         process_group: ProcessGroup,
         bias: bool = True,
         sequence_parallel=True,
+        multiple_of=1,
         device=None,
         dtype=None,
     ) -> None:
         world_size = torch.distributed.get_world_size(process_group)
-        if out_features % world_size != 0:
-            raise ValueError(
-                f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})"
-            )
+        if out_features % multiple_of:
+            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
+        multiple = out_features // multiple_of
+        # We want to split @multiple across world_size, but it could be an uneven split
+        div = multiple // world_size
+        mod = multiple % world_size
+        # The first @mod ranks get @div + 1 copies, the rest get @div copies
+        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
         super().__init__(
-            in_features, out_features // world_size, bias=bias, device=device, dtype=dtype
+            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
         )
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
@@ -205,15 +210,20 @@ class RowParallelLinear(nn.Linear):
         process_group: ProcessGroup,
         bias: bool = True,
         sequence_parallel=True,
+        multiple_of=1,
         device=None,
         dtype=None,
     ) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         rank = torch.distributed.get_rank(process_group)
-        if in_features % world_size != 0:
-            raise ValueError(
-                f"in_features ({in_features}) must be divisible by " f"world_size ({world_size})"
-            )
+        if in_features % multiple_of:
+            raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
+        multiple = in_features // multiple_of
+        # We want to split @multiple across world_size, but it could be an uneven split
+        div = multiple // world_size
+        mod = multiple % world_size
+        # The first @mod ranks get @div + 1 copies, the rest get @div copies
+        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
         # Only rank 0 will have bias
         super().__init__(
             in_features // world_size,
...
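The sharding rule in both hunks is easy to check in isolation. Below is a minimal standalone sketch (not part of the commit; shard_size is a hypothetical helper name) that reproduces the arithmetic from the diff and shows a split the old divisibility check would have rejected:

# Hypothetical helper reproducing the commit's sharding arithmetic;
# plain Python, no process group needed.
def shard_size(features: int, rank: int, world_size: int, multiple_of: int = 1) -> int:
    """Number of features owned by `rank` under the new uneven-split rule."""
    if features % multiple_of:
        raise ValueError(f"features ({features}) must be a multiple of {multiple_of}")
    multiple = features // multiple_of
    # Split @multiple across world_size; it may be an uneven split.
    div, mod = divmod(multiple, world_size)
    # The first @mod ranks get @div + 1 blocks, the rest get @div blocks.
    local_multiple = div + int(rank < mod)
    return local_multiple * multiple_of

# out_features=1280 with multiple_of=128 gives 10 blocks to spread over 3 ranks.
sizes = [shard_size(1280, rank, world_size=3, multiple_of=128) for rank in range(3)]
print(sizes)                 # [512, 384, 384]
assert sum(sizes) == 1280    # shards always reassemble to the full dimension

Before this change, out_features=1280 with world_size=3 would have raised, since 1280 is not divisible by 3; now only divisibility by multiple_of is required, and the lowest-numbered ranks simply hold one extra block of multiple_of features each.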