gaoqiong / flash-attention · Commits

Commit 1ec09ebd, authored Jan 01, 2023 by Tri Dao

[FusedDense] Limit matrix dims to 2M (instead of 64k)

Parent: 714c1b4f

Showing 1 changed file with 9 additions and 9 deletions:
flash_attn/ops/fused_dense.py (+9, -9)
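For scale, here is the arithmetic behind the commit title as a small Python sketch (the constant names are mine, not from the diff):

OLD_LIMIT = 64 * 1024   # 65,536: the previous cap on the batch dimension ("64k")
NEW_LIMIT = 65535 * 32  # 2,097,120: the new per-dimension cap ("~2M")
print(NEW_LIMIT / OLD_LIMIT)  # roughly 32x more headroom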
@@ -46,9 +46,11 @@ class FusedDenseFunc(torch.autograd.Function):
         weight = weight.contiguous()
         if process_group is not None:
             handle_x.wait()
-        batch_shape = total_x.shape[:-1]
+        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight.shape) > 65535 * 32:
+            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
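A minimal standalone sketch of the new guard's semantics (the helper name check_matrix_dims is hypothetical, not part of the diff). Because the condition takes min(...), it raises only when every participating dimension exceeds 65535 * 32 = 2,097,120 (~2M); the old assert rejected any batch dimension above 64k regardless of the other dims.

import torch

def check_matrix_dims(total_x: torch.Tensor, weight: torch.Tensor) -> None:
    # Hypothetical standalone version of the guard added in this hunk.
    batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
    batch_dim = batch_shape.numel()
    if min(batch_dim, n, *weight.shape) > 65535 * 32:
        raise RuntimeError('fused_dense only supports matrix dims <= 2M')

# Meta tensors carry shape only (no memory is allocated), which is enough
# to exercise the check.
x = torch.empty(256, 512, 1024, device='meta')  # batch_dim = 131,072 > 64k
w = torch.empty(4096, 1024, device='meta')
check_matrix_dims(x, w)  # passes now; the old 64k assert would have fired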
@@ -105,11 +107,9 @@ class FusedDenseFunc(torch.autograd.Function):
 def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                      return_residual: bool = False,
                      process_group: Optional[ProcessGroup] = None):
-    batch_dim = x.shape[:-1].numel()
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
-    if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024
-            and dtype_eligible):
+    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
         return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
     else:
         assert process_group is None
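The net effect of this hunk is that dispatch to the fused path no longer depends on the batch size, only on device placement and dtype. A sketch of the post-commit eligibility test as a standalone predicate (the name is_fused_eligible is mine):

import torch
from typing import Optional

def is_fused_eligible(x: torch.Tensor, weight: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> bool:
    # fp16/bf16 always qualify; fp32 qualifies only under autocast.
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    return (x.is_cuda and weight.is_cuda
            and (bias is None or bias.is_cuda) and dtype_eligible)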
@@ -222,7 +222,9 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
+            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         if heuristic == -1:
             gelu_in = F.linear(total_x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
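A note on the unchanged context lines: the heuristic == -1 path falls back to plain PyTorch ops and uses the tanh approximation of GELU. A quick sketch of how close that approximation is (the printed magnitude is typical, not guaranteed):

import torch
import torch.nn.functional as F

x = torch.randn(1 << 16)
exact = F.gelu(x)                       # erf-based GELU
approx = F.gelu(x, approximate='tanh')  # tanh approximation used here
print((exact - approx).abs().max())     # typically on the order of 1e-3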
@@ -348,12 +350,10 @@ def fused_dense_gelu_dense_func(
                                 checkpoint_lvl: int = 0, heuristic: int = 0,
                                 process_group: Optional[ProcessGroup] = None):
-    batch_dim = x.shape[:-1].numel()
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
-            and (bias2 is None or bias2.is_cuda) and batch_dim <= 64 * 1024
-            and dtype_eligible):
+            and (bias2 is None or bias2.is_cuda) and dtype_eligible):
         return FusedDenseGeluDenseFunc.apply(x, weight1, bias1, weight2, bias2, save_pre_act,
                                              return_residual, checkpoint_lvl, heuristic,
                                              process_group
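For orientation, a hedged usage sketch of the dense -> GELU -> dense call this function fuses. The positional argument order is read off the FusedDenseGeluDenseFunc.apply call above; the shapes are illustrative, and the call itself is left commented since the full public signature is cut off in this diff view:

import torch

# Illustrative transformer-MLP shapes: hidden 1024, intermediate 4096.
x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
weight1 = torch.randn(4096, 1024, device='cuda', dtype=torch.float16)
bias1 = torch.randn(4096, device='cuda', dtype=torch.float16)
weight2 = torch.randn(1024, 4096, device='cuda', dtype=torch.float16)
bias2 = torch.randn(1024, device='cuda', dtype=torch.float16)

# out = fused_dense_gelu_dense_func(x, weight1, bias1, weight2, bias2)
# After this commit the call is no longer rejected once x.shape[:-1].numel()
# exceeds 64k; a RuntimeError is raised only when every matrix dim
# exceeds 65535 * 32 (~2M).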