OpenDAS / ColossalAI · Commits

Unverified commit aae49663, authored Nov 22, 2023 by flybird11111, committed by GitHub on Nov 22, 2023.

[shardformer]fix flash attention, when mask is casual, just don't unpad it (#5084)

* fix flash attn
* fix fix

parent 75af66cd
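
The change applies one pattern across the shardformer modeling files listed below: the boolean flash-attention mask is built first, and the attention-mask type is only switched to the padded-causal (unpad) path when that mask actually contains padding; a purely causal batch keeps AttnMaskType.causal and skips unpadding entirely. The sketch below is a minimal, self-contained illustration of that pattern, not the ColossalAI implementation: the local AttnMaskType enum and the helper select_mask_type are stand-ins invented for this sketch, while the `not torch.all(...)` guard mirrors the lines added in the diffs.

import torch
from enum import Enum, auto
from typing import Optional

# Local stand-in for ColossalAI's AttnMaskType enum; it exists only to keep
# this sketch self-contained.
class AttnMaskType(Enum):
    causal = auto()
    padding = auto()
    paddedcausal = auto()

def select_mask_type(flash_attention_mask: Optional[torch.Tensor]) -> AttnMaskType:
    # flash_attention_mask is a per-token keep-mask (True = real token,
    # False = padding), as produced by inverting the HF attention mask.
    if flash_attention_mask is None:
        # No mask supplied: treat the batch as purely causal.
        return AttnMaskType.causal
    if not torch.all(flash_attention_mask):
        # At least one padded position: take the padded-causal path, which
        # unpads the batch before calling flash attention.
        return AttnMaskType.paddedcausal
    # A mask was supplied but nothing is padded: stay on the plain causal
    # path and skip the unpad/repad round trip.
    return AttnMaskType.causal

# A fully unpadded batch stays on the causal fast path.
assert select_mask_type(torch.ones(2, 8, dtype=torch.bool)) is AttnMaskType.causal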
Showing 6 changed files with 16 additions and 8 deletions:

  colossalai/shardformer/modeling/chatglm2.py  +2 -1
  colossalai/shardformer/modeling/gpt2.py      +5 -4
  colossalai/shardformer/modeling/llama.py     +2 -1
  colossalai/shardformer/modeling/opt.py       +2 -1
  colossalai/shardformer/modeling/whisper.py   +4 -1
  examples/language/llama2/pretrain.py         +1 -0
colossalai/shardformer/modeling/chatglm2.py
@@ -51,6 +51,7 @@ def get_flash_core_attention_forward():
             attn_mask_type = AttnMaskType.causal
         else:
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(
colossalai/shardformer/modeling/gpt2.py
@@ -771,11 +771,12 @@ def get_gpt2_flash_attention_forward():
         attn_mask_type = AttnMaskType.causal
         flash_attention_mask = None
         if attention_mask != None:
-            if attn_mask_type == AttnMaskType.causal:
-                attn_mask_type == AttnMaskType.paddedcausal
-            else:
-                attn_mask_type = AttnMaskType.padding
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+            if not torch.all(flash_attention_mask):
+                if attn_mask_type == AttnMaskType.causal:
+                    attn_mask_type == AttnMaskType.paddedcausal
+                else:
+                    attn_mask_type = AttnMaskType.padding
 
         scale = value.size(-1) ** -0.5
         if self.scale_attn_by_inverse_layer_idx:
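
For readers tracing the guard: after the ~(...) inversion above, flash_attention_mask is a boolean keep-mask (True for real tokens, False for padded ones), so torch.all(...) holds exactly when the batch contains no padding. A small illustration with made-up tensors (not taken from the model code):

import torch

# Hypothetical keep-masks of shape (batch, seq_len); True marks a real token.
no_padding = torch.tensor([[True, True, True, True]])
with_padding = torch.tensor([[True, True, True, False]])

# The added guard only switches to the padded-causal (unpad) path when some
# position is actually padded.
print(not torch.all(no_padding))    # False -> keep AttnMaskType.causal
print(not torch.all(with_padding))  # True  -> AttnMaskType.paddedcausal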
colossalai/shardformer/modeling/llama.py
@@ -465,6 +465,7 @@ def get_llama_flash_attention_forward():
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
colossalai/shardformer/modeling/opt.py
@@ -581,6 +581,7 @@ def get_opt_flash_attention_forward():
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(
colossalai/shardformer/modeling/whisper.py
@@ -106,7 +106,10 @@ def get_whisper_flash_attention_forward():
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
-            attn_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_type = AttnMaskType.paddedcausal
+            else:
+                attn_type = AttnMaskType.causal
 
         attention = ColoAttention(
             embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
examples/language/llama2/pretrain.py
@@ -76,6 +76,7 @@ def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = Non
 
 def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    tensor = tensor.data
     tensor.div_(dist.get_world_size())
     return tensor
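
A note on the pretrain.py hunk: all_reduce_mean turns a per-rank value into a global mean by summing across ranks and dividing by the world size; the added `tensor = tensor.data` presumably rebinds to the underlying data tensor so that the in-place `div_` stays outside autograd tracking. A commented restatement under that assumption (running it requires an initialized torch.distributed process group):

import torch
import torch.distributed as dist

def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    # Sum the tensor in place across all ranks of the default process group.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    # Rebind to .data so the in-place division below is not recorded by
    # autograd (the apparent motivation for the line added in this commit).
    tensor = tensor.data
    # Divide by the number of ranks to turn the sum into a mean.
    tensor.div_(dist.get_world_size())
    return tensor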