apex · Commit 14ccf598 (unverified)
Authored Oct 09, 2021 by Masaki Kozuki; committed by GitHub on Oct 08, 2021

Remove `custom_fwd`/`custom_bwd` from fused softmax (#1188)

* run backward
* remove custom_fwd/custom_bwd
Parent: 3ad9db2a

Changes: 2 files changed, 49 additions and 21 deletions
  apex/transformer/functional/fused_softmax.py    +0  -3
  tests/L0/run_transformer/test_fused_softmax.py  +49 -18
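For context, `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd` are PyTorch decorators that make a custom `torch.autograd.Function` cooperate with autocast; `custom_fwd(cast_inputs=torch.half)` additionally casts incoming floating-point tensors to fp16 while autocast is enabled. Below is a minimal sketch of the decorator pattern this commit removes from the fused softmax Functions; the `ScaledSoftmaxSketch` class and its plain-PyTorch body are illustrative stand-ins, not the apex CUDA kernels.

import torch

class ScaledSoftmaxSketch(torch.autograd.Function):
    # Hypothetical stand-in for apex's fused softmax Functions, showing the
    # @custom_fwd/@custom_bwd decorators that this commit deletes.

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)  # removed in this commit
    def forward(ctx, inputs, scale):
        # With cast_inputs=torch.half, callers inside an autocast region always
        # hand fp16 tensors to this op.
        probs = torch.softmax(inputs * scale, dim=-1)
        ctx.save_for_backward(probs)
        ctx.scale = scale
        return probs

    @staticmethod
    @torch.cuda.amp.custom_bwd  # removed in this commit
    def backward(ctx, grad_output):
        probs, = ctx.saved_tensors
        # Standard softmax backward: p * (g - sum(g * p)), scaled by the forward scale.
        grad_inputs = probs * (grad_output - (grad_output * probs).sum(dim=-1, keepdim=True))
        return grad_inputs * ctx.scale, None

Without the decorators, the fused kernels receive whatever dtype the caller (or autocast) actually produces, and the expanded tests below now exercise the backward pass as well.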
apex/transformer/functional/fused_softmax.py

@@ -37,7 +37,6 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         return softmax_results
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
     def backward(ctx, output_grads):
         import scaled_upper_triang_masked_softmax_cuda
 
@@ -68,7 +67,6 @@ def scaled_upper_triang_masked_softmax(inputs, _, scale):
 class ScaledMaskedSoftmax(torch.autograd.Function):
     @staticmethod
-    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, inputs, mask, scale):
         import scaled_masked_softmax_cuda
 
         scale_t = torch.tensor([scale])
 
@@ -78,7 +76,6 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return softmax_results
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
     def backward(ctx, output_grads):
         import scaled_masked_softmax_cuda
tests/L0/run_transformer/test_fused_softmax.py

@@ -12,8 +12,7 @@ from apex.transformer.functional import FusedScaleMaskSoftmax
 def attention_mask_func(attention_scores, attention_mask):
-    attention_scores.masked_fill_(attention_mask, -10000.0)
-    return attention_scores
+    return attention_scores.masked_fill(attention_mask, -10000.0)
 
 
 autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
@@ -61,11 +60,19 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
             return
         fused_fn, torch_fn = self._setup_fused_softmax(
             input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.padding)
-        attention_scores = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype)
+        attention_scores_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
+        with torch.no_grad():
+            attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
         mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()
-        reference = fused_fn(attention_scores, mask)
-        actual = torch_fn(attention_scores, mask)
-        torch.testing.assert_allclose(actual, reference)
+        expected = fused_fn(attention_scores_0, mask)
+        actual = torch_fn(attention_scores_1, mask)
+        torch.testing.assert_allclose(actual, expected)
+        g0 = torch.rand_like(actual)
+        with torch.no_grad():
+            g1 = g0.clone()
+        expected.backward(g0)
+        actual.backward(g1)
 
     def test_autocast_fused_scale_mask_softmax(self):
         for dtype in autocast_dtypes:
@@ -74,16 +81,24 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
             input_in_bf16 = dtype == torch.bfloat16
             fused_fn, torch_fn = self._setup_fused_softmax(
                 input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding)
-            attention_scores = torch.randn((4, 12, 24, 24)).cuda()
+            attention_scores_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
+            with torch.no_grad():
+                attention_scores_1 = attention_scores_0.clone().to(dtype).requires_grad_(True)
             mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda()
+            expected = torch_fn(attention_scores_1, mask)
             with torch.cuda.amp.autocast(dtype=dtype):
-                actual = fused_fn(attention_scores, mask)
+                actual = fused_fn(attention_scores_0, mask)
                 self.assertEqual(actual.dtype, dtype)
-            with torch.no_grad():
-                expected = torch_fn(attention_scores.to(dtype), mask)
             torch.testing.assert_allclose(actual, expected)
+            g0 = torch.rand_like(actual)
+            with torch.no_grad():
+                g1 = g0.clone()
+            expected.backward(g0)
+            actual.backward(g1)
 
     def test_fused_upper_triangle_mask_softmax(self):
         """
         attn_weights.shape: [4, 12, 24, 24]
@@ -108,14 +123,22 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
         fused_fn, torch_fn = self._setup_fused_softmax(
             input_in_fp16, input_in_bf16, scale, softmax_in_fp32, AttnMaskType.causal)
-        attn_weights = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype)
+        attn_weights_0 = torch.randn((4, 12, 24, 24)).to(device="cuda", dtype=dtype).requires_grad_(True)
+        with torch.no_grad():
+            attn_weights_1 = attn_weights_0.clone().requires_grad_(True)
         total_mask = (~(
             torch.tril(torch.randn((24, 24), device="cuda")).bool()
         ).unsqueeze(0).unsqueeze(0))
         total_mask = total_mask.repeat((4, 1, 1, 1))
-        reference = fused_fn(attn_weights, total_mask)
-        actual = torch_fn(attn_weights, total_mask)
-        torch.testing.assert_allclose(actual, reference)
+        expected = fused_fn(attn_weights_0, total_mask)
+        actual = torch_fn(attn_weights_1, total_mask)
+        torch.testing.assert_allclose(actual, expected)
+        g0 = torch.randn_like(actual)
+        with torch.no_grad():
+            g1 = g0.clone()
+        actual.backward(g0)
+        expected.backward(g1)
 
     def test_autocast_fused_upper_triangle_mask_softmax(self):
         for dtype in autocast_dtypes:
@@ -124,14 +147,22 @@ class TestFusedScaleMaskSoftmax(unittest.TestCase):
             input_in_bf16 = dtype == torch.bfloat16
             fused_fn, torch_fn = self._setup_fused_softmax(
                 input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal)
-            attn_weights = torch.randn((4, 12, 24, 24)).cuda()
+            attn_weights_0 = torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
+            with torch.no_grad():
+                attn_weights_1 = attn_weights_0.clone().to(dtype).requires_grad_(True)
             total_mask = (~(
                 torch.tril(torch.randn((24, 24), device="cuda")).bool()
             ).unsqueeze(0).unsqueeze(0))
             with torch.cuda.amp.autocast(dtype=dtype):
-                actual = fused_fn(attn_weights, total_mask)
+                actual = fused_fn(attn_weights_0, total_mask)
                 self.assertEqual(actual.dtype, dtype)
-            with torch.no_grad():
-                expected = torch_fn(attn_weights.to(dtype), total_mask)
+            expected = torch_fn(attn_weights_1, total_mask)
             torch.testing.assert_allclose(actual, expected)
+            g0 = torch.randn_like(actual)
+            with torch.no_grad():
+                g1 = g0.clone()
+            actual.backward(g0)
+            expected.backward(g1)