gaoqiong / flash-attention

Commit abf04a56 (Unverified)
fix flash ce mp large vocab (#673)

Authored Nov 20, 2023 by Shijie, committed by GitHub on Nov 19, 2023
Parent: db2f8069
Changes: 2 changed files, with 88 additions and 1 deletion (+88 -1)

  flash_attn/ops/triton/cross_entropy.py                     +2  -1
  tests/losses/test_cross_entropy_parallel_large_vocab.py    +86 -0
flash_attn/ops/triton/cross_entropy.py @ abf04a56

@@ -197,8 +197,9 @@ class CrossEntropyLoss(torch.autograd.Function):
         # For labels not in the vocab of this partition, losses contains
         # -0.1 * sum logit / total_classes.
         if world_size > 1:
-            lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
+            lse_allgather = torch.empty(world_size * n_splits, n_rows, dtype=lse.dtype, device=lse.device)
             torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
+            if n_splits > 1:
+                losses = losses.sum(dim=0)
             handle_losses = torch.distributed.all_reduce(
                 losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
             )
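Note on the fix: on the large-vocab path this commit targets, each rank's vocabulary shard is itself processed in n_splits chunks, so lse arrives with shape (n_splits, n_rows) rather than (n_rows,). The all-gather buffer therefore needs world_size * n_splits rows, and the per-split partial losses have to be summed over the split dimension before the cross-rank all_reduce. A minimal single-process sketch of the shape logic (plain PyTorch, not the library's Triton kernel; all sizes and names here are illustrative only):

import torch

# Illustrative sizes only.
n_rows, vocab_size, world_size, n_splits = 4, 16, 2, 2
logits = torch.randn(n_rows, vocab_size)

# Each (rank, split) pair contributes the logsumexp of its slice of the vocab;
# stacked, that is the (world_size * n_splits, n_rows) buffer the fix allocates.
chunks = logits.chunk(world_size * n_splits, dim=-1)
lse_allgather = torch.stack([c.logsumexp(dim=-1) for c in chunks], dim=0)

# Reducing over dim 0 recovers the logsumexp over the full vocabulary, which is
# why a single logsumexp over the enlarged buffer stays correct after the fix.
assert torch.allclose(lse_allgather.logsumexp(dim=0), logits.logsumexp(dim=-1), atol=1e-6)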
tests/losses/test_cross_entropy_parallel_large_vocab.py @ abf04a56 (new file, 0 → 100644)
# Run test with:
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/losses/test_cross_entropy_parallel_large_vocab.py
import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel

from flash_attn.losses.cross_entropy import CrossEntropyLoss

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
# @pytest.mark.parametrize("lse_square_scale", [1e-2])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [256 * 1024])  # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [50264])  # test vocab larger than 64k for split
@pytest.mark.parametrize("world_size", [2])
# @pytest.mark.parametrize("world_size", [2])
def test_cross_entropy_loss_parallel(
    vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
):
    assert vocab_size % world_size == 0
    rtol, atol = (
        (1e-5, 1e-6)
        if dtype == torch.float32
        else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
    )
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    partition_vocab_size = vocab_size // world_size
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = (
        torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10
    ).requires_grad_()
    x = (
        tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)
        .detach()
        .clone()
        .requires_grad_()
    )
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
    model = CrossEntropyLoss(
        label_smoothing=smoothing,
        reduction="none",
        lse_square_scale=lse_square_scale,
        inplace_backward=inplace_backward,
        process_group=parallel_state.get_tensor_model_parallel_group(),
    )
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    if lse_square_scale > 0.0:
        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
        out_pt += lse_square_scale * lse_pt.square()
        out_pt.masked_fill_(y == -100, 0.0)
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(
        x.grad,
        x_pt.grad[:, (rank * partition_vocab_size) : (rank + 1) * partition_vocab_size],
        rtol=rtol,
        atol=atol,
    )
    parallel_state.destroy_model_parallel()
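As the header comment notes, the test is meant to be launched with torchrun across two processes, and it relies on apex's tensor-parallel helpers plus two CUDA devices. A rough back-of-envelope check (not library code) of why vocab_size = 256 * 1024 exercises the split path this commit fixes, using the 64K threshold mentioned in the test's own comment:

# Hypothetical arithmetic check, mirroring the test's parameters.
vocab_size, world_size = 256 * 1024, 2
partition_vocab_size = vocab_size // world_size  # 131072 columns per tensor-parallel rank
print(partition_vocab_size > 64 * 1024)          # True -> the per-rank shard is split (n_splits > 1)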