gaoqiong / flash-attention · Commit 343492ec
Authored Nov 13, 2022 by Tri Dao

Make nccl operations async in CrossEntropyLossParallel

parent 3dda4f76
Showing 3 changed files with 36 additions and 26 deletions:

  flash_attn/losses/cross_entropy_parallel.py   +34  -24
  tests/losses/test_cross_entropy_apex.py        +1   -1
  tests/losses/test_cross_entropy_parallel.py    +1   -1
flash_attn/losses/cross_entropy_parallel.py

@@ -36,40 +36,50 @@ class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):
         assert labels.shape == (batch,)
         rank = get_tensor_model_parallel_rank()
         world_size = get_tensor_model_parallel_world_size()
-        vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
-            partition_vocab_size, get_tensor_model_parallel_rank(),
-            get_tensor_model_parallel_world_size()
-        )
-        # Create a mask of valid vocab ids (1 means it needs to be masked).
-        labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
-        ignored_mask = labels == ignored_index
-        labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
-        masked_labels = labels_local.clone()
-        masked_labels[labels_mask] = ignored_index
-        losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
-        assert lse_local.shape == (batch,)
-        assert losses.shape == (batch,)
-        losses.masked_fill_(masked_labels == ignored_index, 0)
-        if world_size > 1:
+        if world_size == 1:
+            losses, lse = xentropy_cuda_lib.forward(logits_parallel, labels, smoothing)
+            losses.masked_fill_(labels == ignored_index, 0)
+            labels_local = labels
+        else:
+            vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
+                partition_vocab_size, get_tensor_model_parallel_rank(),
+                get_tensor_model_parallel_world_size()
+            )
+            # Create a mask of valid vocab ids (1 means it needs to be masked).
+            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
+            ignored_mask = labels == ignored_index
+            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
+            masked_labels = labels_local.clone()
+            masked_labels[labels_mask] = ignored_index
+            losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
+            assert lse_local.shape == (batch,)
+            assert losses.shape == (batch,)
+            losses.masked_fill_(masked_labels == ignored_index, 0)
             lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                         device=lse_local.device)
-            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
-                                                     group=get_tensor_model_parallel_group())
-            lse = torch.logsumexp(lse_allgather, dim=0)
-            torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM,
-                                         group=get_tensor_model_parallel_group())
-            # The losses are going to be lse_local - predicted_logit, we just have to subtract
-            # the lse_local and add the lse (global).
-            rank_per_sample = labels // partition_vocab_size
+            handle_lse = torch.distributed.all_gather_into_tensor(
+                lse_allgather, lse_local.contiguous(),
+                group=get_tensor_model_parallel_group(), async_op=True
+            )
+            handle_losses = torch.distributed.all_reduce(
+                losses, op=torch.distributed.ReduceOp.SUM,
+                group=get_tensor_model_parallel_group(), async_op=True
+            )
+            handle_lse.wait()
+            lse = torch.logsumexp(lse_allgather, dim=0)
+            # The losses are currently lse_local - predicted_logit, we just have to subtract the
+            # lse_local and add the lse (global).
+            rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode='floor')
             lse_local = lse_allgather[rank_per_sample,
                                       torch.arange(batch, device=lse_allgather.device)]
+            handle_losses.wait()
             losses += lse - lse_local
             losses.masked_fill_(ignored_mask, 0)
-        else:
-            lse = lse_local
         ctx.save_for_backward(logits_parallel, lse, labels_local)
         ctx.smoothing = smoothing
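The core of the change is the overlap pattern: both collectives are launched with async_op=True, and each handle's wait() is deferred until just before its result is consumed, so the all_reduce on losses proceeds while the global logsumexp is computed from the gathered values. A minimal self-contained sketch of that pattern, assuming a NCCL process group is already initialized (the function name and use of the default group are illustrative, not from the repo):

import torch
import torch.distributed as dist

def overlapped_lse_and_losses(lse_local, losses):
    # Assumes dist.init_process_group(backend='nccl') has already been called
    # and both tensors live on this rank's GPU with shape (batch,).
    world_size = dist.get_world_size()
    batch = lse_local.shape[0]
    lse_allgather = torch.empty(world_size, batch,
                                dtype=lse_local.dtype, device=lse_local.device)
    # Enqueue both collectives before waiting on either, so NCCL can run them
    # concurrently instead of back to back.
    handle_lse = dist.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
                                             async_op=True)
    handle_losses = dist.all_reduce(losses, op=dist.ReduceOp.SUM, async_op=True)
    handle_lse.wait()                            # gathered lse is needed first
    lse = torch.logsumexp(lse_allgather, dim=0)  # global normalizer across ranks
    handle_losses.wait()                         # reduced losses needed only now
    return lse, losses

The correction applied after the waits follows from each per-rank loss being lse_local - predicted_logit: subtracting the owning rank's lse_local and adding the global lse turns it into the cross entropy over the full vocabulary.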
tests/losses/test_cross_entropy_apex.py

@@ -6,7 +6,7 @@ import pytest
 from einops import rearrange
-from src.losses.cross_entropy_apex import CrossEntropyLossApex
+from flass_attn.losses.cross_entropy_apex import CrossEntropyLossApex
 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
...
tests/losses/test_cross_entropy_parallel.py
View file @
343492ec
...
@@ -10,7 +10,7 @@ import pytest
...
@@ -10,7 +10,7 @@ import pytest
from
apex.transformer
import
parallel_state
from
apex.transformer
import
parallel_state
from
apex.transformer
import
tensor_parallel
from
apex.transformer
import
tensor_parallel
from
src
.losses.cross_entropy_parallel
import
CrossEntropyLossParallel
from
flash_attn
.losses.cross_entropy_parallel
import
CrossEntropyLossParallel
is_sm8x
=
torch
.
cuda
.
get_device_capability
(
'cuda'
)[
0
]
>=
8
is_sm8x
=
torch
.
cuda
.
get_device_capability
(
'cuda'
)[
0
]
>=
8
...
...