gaoqiong / flash-attention · Commits · aaa14741
"INSTALL/plugin/ventoy/ventoy.json" did not exist on "0f8478fbe1ecbcfd7a1f189d1ca2a60d05cdf322"
Commit aaa14741 · authored Nov 19, 2023 by Tri Dao

[CrossEntropy] Simplify the case of large vocab with Tensor Parallel

Parent: abf04a56
Showing 3 changed files with 10 additions and 98 deletions:

  +4   −5    flash_attn/ops/triton/cross_entropy.py
  +6   −7    tests/losses/test_cross_entropy_parallel.py
  +0   −86   tests/losses/test_cross_entropy_parallel_large_vocab.py
flash_attn/ops/triton/cross_entropy.py  (view file @ aaa14741)

@@ -196,18 +196,17 @@ class CrossEntropyLoss(torch.autograd.Function):
         # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
         # For labels not in the vocab of this partition, losses contains
         # -0.1 * sum logit / total_classes.
+        if n_splits > 1:
+            lse = torch.logsumexp(lse, dim=0)
+            losses = losses.sum(dim=0)
         if world_size > 1:
-            lse_allgather = torch.empty(world_size * n_splits, n_rows, dtype=lse.dtype, device=lse.device)
+            lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
             torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
-            if n_splits > 1:
-                losses = losses.sum(dim=0)
             handle_losses = torch.distributed.all_reduce(
                 losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
             )
             lse = torch.logsumexp(lse_allgather, dim=0)
             handle_losses.wait()
-        else:
-            lse = torch.logsumexp(lse, dim=0)
-            losses = losses.sum(dim=0)
         # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
         # we just have to add the (global) lse.
         # If there's smoothing=0.1, the total losses are
...
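For reference, the final torch.logsumexp(lse_allgather, dim=0) works because a logsumexp of per-chunk logsumexps equals the logsumexp over the full vocab dimension, so each rank (or each split of a large per-rank vocab) only has to contribute its local lse. A minimal single-process sketch of that identity, not part of the commit and with illustrative sizes:

    import torch

    torch.manual_seed(0)
    n_rows, vocab_size, n_chunks = 4, 1024, 8  # illustrative sizes, not from the repo
    logits = torch.randn(n_rows, vocab_size)

    # Per-chunk lse, shape (n_chunks, n_rows), analogous to the gathered lse_allgather.
    lse_chunks = torch.stack([chunk.logsumexp(dim=-1) for chunk in logits.chunk(n_chunks, dim=-1)])
    lse_combined = torch.logsumexp(lse_chunks, dim=0)  # same reduction as over lse_allgather
    lse_full = torch.logsumexp(logits, dim=-1)
    assert torch.allclose(lse_combined, lse_full, atol=1e-5)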
tests/losses/test_cross_entropy_parallel.py  (view file @ aaa14741)

 # Run test with:
-# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/losses/test_cross_entropy_parallel.py
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/losses/test_cross_entropy_parallel.py
 import math

 import pytest
 import torch
 import torch.nn.functional as F
 from apex.transformer import parallel_state, tensor_parallel

 from flash_attn.losses.cross_entropy import CrossEntropyLoss

@@ -19,19 +18,19 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize("inplace_backward", [False, True])
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
 # @pytest.mark.parametrize("lse_square_scale", [1e-2])
 # @pytest.mark.parametrize("lse_square_scale", [0.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
-@pytest.mark.parametrize("vocab_size", [50264, 128 * 1024])  # test vocab larger than 64k for split
+@pytest.mark.parametrize("vocab_size", [50264, 256 * 1024])  # test vocab larger than 64k for split
 # @pytest.mark.parametrize("vocab_size", [50264])  # test vocab larger than 64k for split
-@pytest.mark.parametrize("world_size", [1, 2, 4])
-# @pytest.mark.parametrize("world_size", [2])
+# @pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("world_size", [2])
 def test_cross_entropy_loss_parallel(
     vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
 ):
     assert vocab_size % world_size == 0
     rtol, atol = (
-        (1e-5, 1e-6)
+        (1e-5, 2e-5)
         if dtype == torch.float32
         else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
     )
...
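The smoothing parametrization above connects to the comment quoted in cross_entropy.py ("-0.9 * predicted logit - 0.1 * sum logit / total_classes", plus the global lse). A hedged single-process check of that decomposition against torch.nn.CrossEntropyLoss, with illustrative sizes rather than the test's configuration:

    import torch

    torch.manual_seed(0)
    n_rows, vocab_size, smoothing = 8, 50264, 0.1  # illustrative; 50264 matches one of the test vocab sizes
    logits = torch.randn(n_rows, vocab_size)
    target = torch.randint(0, vocab_size, (n_rows,))

    ref = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")(logits, target)
    lse = torch.logsumexp(logits, dim=-1)
    # lse - 0.9 * predicted logit - 0.1 * sum(logits) / total_classes, as in the comment above.
    manual = (
        lse
        - (1 - smoothing) * logits[torch.arange(n_rows), target]
        - smoothing * logits.sum(dim=-1) / vocab_size
    )
    assert torch.allclose(ref, manual, atol=1e-4)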
tests/losses/test_cross_entropy_parallel_large_vocab.py  (deleted, 100644 → 0; view file @ abf04a56)

# Run test with:
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/losses/test_cross_entropy_parallel_large_vocab.py
import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel

from flash_attn.losses.cross_entropy import CrossEntropyLoss

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
# @pytest.mark.parametrize("lse_square_scale", [1e-2])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [256 * 1024])  # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [50264])  # test vocab larger than 64k for split
@pytest.mark.parametrize("world_size", [2])
# @pytest.mark.parametrize("world_size", [2])
def test_cross_entropy_loss_parallel(
    vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
):
    assert vocab_size % world_size == 0
    rtol, atol = (
        (1e-5, 1e-6)
        if dtype == torch.float32
        else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
    )
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    partition_vocab_size = vocab_size // world_size
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = (
        torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10
    ).requires_grad_()
    x = (
        tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)
        .detach()
        .clone()
        .requires_grad_()
    )
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
    model = CrossEntropyLoss(
        label_smoothing=smoothing,
        reduction="none",
        lse_square_scale=lse_square_scale,
        inplace_backward=inplace_backward,
        process_group=parallel_state.get_tensor_model_parallel_group(),
    )
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    if lse_square_scale > 0.0:
        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
        out_pt += lse_square_scale * lse_pt.square()
        out_pt.masked_fill_(y == -100, 0.0)
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(
        x.grad,
        x_pt.grad[:, (rank * partition_vocab_size) : (rank + 1) * partition_vocab_size],
        rtol=rtol,
        atol=atol,
    )
    parallel_state.destroy_model_parallel()
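The deleted test exercised this recipe across two GPUs through apex tensor parallelism. A rough single-process sketch of the same recipe, with made-up sizes and a plain Python loop standing in for the ranks (so no torch.distributed, apex, or GPU is needed): the sum over partial losses plays the role of the all_reduce, and the logsumexp over local lses plays the role of the all_gather plus logsumexp.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    world_size, n_rows, vocab_size = 2, 16, 512  # illustrative sizes, not the test's
    partition = vocab_size // world_size
    logits = torch.randn(n_rows, vocab_size)
    target = torch.randint(0, vocab_size, (n_rows,))

    partial_losses = torch.zeros(world_size, n_rows)
    local_lse = torch.zeros(world_size, n_rows)
    for rank in range(world_size):  # each iteration plays the role of one tensor-parallel rank
        lo, hi = rank * partition, (rank + 1) * partition
        chunk = logits[:, lo:hi]
        local_lse[rank] = torch.logsumexp(chunk, dim=-1)
        # Only the rank that owns the target class contributes -logit[target].
        in_partition = (target >= lo) & (target < hi)
        rows = in_partition.nonzero(as_tuple=True)[0]
        partial_losses[rank, rows] = -chunk[rows, target[rows] - lo]

    losses = partial_losses.sum(dim=0)       # stands in for the all_reduce of partial losses
    lse = torch.logsumexp(local_lse, dim=0)  # stands in for the all_gather + logsumexp of local lses
    assert torch.allclose(losses + lse, F.cross_entropy(logits, target, reduction="none"), atol=1e-5)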