gaoqiong / flash-attention · Commits

Commit c6ecd40a authored Dec 27, 2022 by Tri Dao

Tweak CrossEntropyLoss to take process_group in init

parent b4018a50

Showing 6 changed files with 22 additions and 14 deletions (+22 -14)
flash_attn/losses/cross_entropy.py           +4 -3
flash_attn/models/bert.py                    +1 -6
flash_attn/utils/distributed.py              +7 -3
flash_attn/utils/pretrained.py               +8 -0
tests/losses/test_cross_entropy_parallel.py  +1 -1
tests/models/test_bert.py                    +1 -1
flash_attn/losses/cross_entropy.py

@@ -106,7 +106,7 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
 class CrossEntropyLoss(nn.Module):

     def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
-                 inplace_backward=False):
+                 inplace_backward=False, process_group=None):
         super().__init__()
         if reduction not in ['mean', 'none']:
             raise NotImplementedError("Only support reduction = 'mean' or 'none'")
@@ -114,13 +114,14 @@ class CrossEntropyLoss(nn.Module):
         self.reduction = reduction
         self.label_smoothing = label_smoothing
         self.inplace_backward = inplace_backward
+        self.process_group = process_group

-    def forward(self, input, target, process_group=None):
+    def forward(self, input, target):
         assert input.is_cuda and target.is_cuda
         # SoftmaxCrossEntropyLoss implicitly casts to float
         loss = SoftmaxCrossEntropyLossFn.apply(
             input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
-            process_group
+            self.process_group
         )
         if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
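For callers, the effect is that the process group is now bound once at construction instead of being threaded through every forward call. A minimal before/after sketch, assuming a tensor-parallel group has already been created (the names tp_group, logits, and labels are illustrative, not from the commit):

from flash_attn.losses.cross_entropy import CrossEntropyLoss

# Before this commit, the group was passed at call time:
#   loss_fn = CrossEntropyLoss(inplace_backward=True)
#   loss = loss_fn(logits, labels, process_group=tp_group)

# After this commit, the group is fixed in __init__ and forward takes only (input, target):
loss_fn = CrossEntropyLoss(inplace_backward=True, process_group=tp_group)
loss = loss_fn(logits, labels)  # both arguments must be CUDA tensors, per the assert in forward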
flash_attn/models/bert.py

@@ -28,6 +28,7 @@ from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.bert_padding import unpad_input, pad_input
 from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
+from flash_attn.utils.pretrained import state_dict_from_pretrained

 try:
     from flash_attn.ops.fused_dense import FusedDense
@@ -439,12 +440,6 @@ class BertForPreTraining(BertPreTrainedModel):
         )

-def state_dict_from_pretrained(model_name):
-    from transformers.utils import WEIGHTS_NAME
-    from transformers.utils.hub import cached_file
-    return torch.load(cached_file(model_name, WEIGHTS_NAME))
-
-
 def remap_state_dict(state_dict, config):

     # LayerNorm
     def key_mapping_ln_gamma_beta(key):
flash_attn/utils/distributed.py

@@ -87,11 +87,15 @@ def sync_sequence_parallel_params(model: torch.nn.Module, process_group: ProcessGroup):
     )


 # Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
 def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
     # We want to iterate over parameters with _sequence_parallel=True in the same order,
     # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
     params_seqparallel = {name: p for name, p in model.named_parameters()
                           if getattr(p, '_sequence_parallel', False)}
-    for _, p in sorted(params_seqparallel.items()):
-        with torch.no_grad():
-            torch.distributed.all_reduce(p.grad, group=process_group)
+    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
+    with torch.no_grad():
+        coalesced = torch._utils._flatten_dense_tensors(grads)
+        torch.distributed.all_reduce(coalesced, group=process_group)
+        for buf, synced in zip(grads,
+                               torch._utils._unflatten_dense_tensors(coalesced, grads)):
+            buf.copy_(synced)
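This change replaces one all_reduce per sequence-parallel gradient with a single all_reduce over a coalesced buffer. A standalone sketch of the coalescing pattern (using the private torch._utils helpers the diff relies on): toy tensors stand in for p.grad, and an in-place scaling stands in for the collective call, so no process group is needed to run it.

import torch

# Toy stand-ins for the gradients collected from model.named_parameters().
grads = [torch.ones(3), torch.full((2, 2), 2.0)]

# Flatten everything into one contiguous 1-D buffer so a single collective suffices.
coalesced = torch._utils._flatten_dense_tensors(grads)
coalesced *= 0.5  # stand-in for torch.distributed.all_reduce(coalesced, group=process_group)

# Scatter the reduced values back into the original tensors, in place.
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)

print(grads[0])  # tensor([0.5000, 0.5000, 0.5000])
print(grads[1])  # tensor([[1., 1.], [1., 1.]])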
flash_attn/utils/pretrained.py (new file, mode 100644)

import torch

from transformers.utils import WEIGHTS_NAME
from transformers.utils.hub import cached_file


def state_dict_from_pretrained(model_name):
    return torch.load(cached_file(model_name, WEIGHTS_NAME))
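A usage sketch for the new helper (the model name is taken from the test parametrization below; it requires network access to the Hugging Face hub or a locally cached checkpoint):

from flash_attn.utils.pretrained import state_dict_from_pretrained

# Resolves the checkpoint file for the named model via the hub cache and loads it
# as an ordinary state dict mapping parameter names to tensors.
state_dict = state_dict_from_pretrained("bert-base-uncased")
print(len(state_dict), "tensors loaded")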
tests/losses/test_cross_entropy_parallel.py

@@ -24,7 +24,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('vocab_size', [50264])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
-def test_cross_entropy_loss_apex(vocab_size, world_size, smoothing, inplace_backward, dtype):
+def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
     assert vocab_size % world_size == 0
     rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
                   else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
tests/models/test_bert.py

@@ -12,8 +12,8 @@ from transformers.models.bert.modeling_bert import BertModel as BertModelHF
 from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF

 from flash_attn.models.bert import BertModel, BertForPreTraining
-from flash_attn.models.bert import state_dict_from_pretrained
 from flash_attn.models.bert import remap_state_dict
+from flash_attn.utils.pretrained import state_dict_from_pretrained


 @pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])