OpenDAS / ColossalAI · Commit 304263c2 (unverified)
fix gpt attention mask (#461)
Authored Mar 18, 2022 by ver217; committed by GitHub on Mar 18, 2022. Parent: fc8e6db0

Showing 1 changed file with 24 additions and 3 deletions:
model_zoo/gpt/gpt.py (+24, −3)
```diff
@@ -21,6 +21,7 @@ __all__ = [
 @LAYERS.register_module
 class GPTEmbedding(nn.Module):
     def __init__(self,
                  embedding_dim: int,
                  vocab_size: int,
```
```diff
@@ -56,6 +57,7 @@ class GPTEmbedding(nn.Module):
 @LAYERS.register_module
 class GPTSelfAttention(nn.Module):
     def __init__(self,
                  dim: int,
                  num_heads: int,
```
```diff
@@ -70,7 +72,8 @@ class GPTSelfAttention(nn.Module):
         self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias)
         if fuse_scale_mask_softmax:
             from colossalai.kernel import FusedScaleMaskSoftmax
-            from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+            from colossalai.kernel.cuda_native.scaled_softmax import \
+                AttnMaskType
             self.softmax = FusedScaleMaskSoftmax(input_in_fp16=True,
                                                  input_in_bf16=False,
                                                  attn_mask_type=AttnMaskType.causal,
```
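`FusedScaleMaskSoftmax` with `AttnMaskType.causal` fuses the scale, mask and softmax steps of attention into a single kernel. As a point of reference only (this is not the kernel's implementation, and the helper name below is made up for illustration), an un-fused causal version in plain PyTorch could look like:

```python
import torch

def causal_scale_mask_softmax(scores: torch.Tensor, scale: float) -> torch.Tensor:
    """Un-fused reference: scale logits, apply a causal mask, then softmax."""
    # scores: [batch, num_heads, seq_len, seq_len] raw attention logits
    seq_len = scores.size(-1)
    # causal mask: position i may only attend to positions <= i
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device))
    scores = scores * scale
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1)
```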
```diff
@@ -113,7 +116,7 @@ class GPTSelfAttention(nn.Module):
         x = torch.matmul(x, v)
         x = x.transpose(1, 2)
-        new_context_layer_shape = x.size()[:-2] + (all_head_size, )
+        new_context_layer_shape = x.size()[:-2] + (all_head_size,)
         x = x.reshape(new_context_layer_shape)
         x = self.dense(x)
```
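The reshape in this hunk merges the attention heads back into one hidden dimension: after `transpose(1, 2)` the tensor is `[batch, seq_len, num_heads, head_size]`, and `x.size()[:-2] + (all_head_size,)` appends the flattened head dimension. A small standalone shape check (the sizes are arbitrary, chosen only for illustration):

```python
import torch

batch, num_heads, seq_len, head_size = 2, 12, 16, 64
all_head_size = num_heads * head_size

x = torch.randn(batch, num_heads, seq_len, head_size)       # per-head context vectors
x = x.transpose(1, 2)                                        # [batch, seq_len, num_heads, head_size]
new_context_layer_shape = x.size()[:-2] + (all_head_size,)   # (batch, seq_len, all_head_size)
x = x.reshape(new_context_layer_shape)

print(x.shape)  # torch.Size([2, 16, 768])
```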
```diff
@@ -124,6 +127,7 @@ class GPTSelfAttention(nn.Module):
 @LAYERS.register_module
 class GPTMLP(nn.Module):
     def __init__(self,
                  dim: int,
                  mlp_ratio: float,
```
```diff
@@ -148,6 +152,7 @@ class GPTMLP(nn.Module):
 @LAYERS.register_module
 class GPTBlock(CheckpointModule):
     def __init__(self,
                  dim: int,
                  num_heads: int,
```
```diff
@@ -194,6 +199,7 @@ class GPTBlock(CheckpointModule):
 @LAYERS.register_module
 class GPTLMHead(nn.Module):
     def __init__(self,
                  dim: int,
                  vocab_size: int,
```
```diff
@@ -214,6 +220,7 @@ class GPTLMHead(nn.Module):
 @LOSSES.register_module
 class GPTLMLoss(nn.Module):
     def __init__(self):
         super().__init__()
         self.loss = col_nn.CrossEntropyLoss()
```
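This hunk only shows that `GPTLMLoss` wraps `col_nn.CrossEntropyLoss`. For readers unfamiliar with the objective, GPT-style language-modelling losses typically shift logits and labels by one position so each position predicts the next token; a plain-PyTorch sketch under that assumption (the class and its `forward` signature below are illustrative, not taken from this file):

```python
import torch
import torch.nn as nn

class SimpleGPTLMLoss(nn.Module):
    """Sketch of a next-token cross-entropy loss (assumed usage, not the file's code)."""

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # logits: [batch, seq_len, vocab_size], labels: [batch, seq_len]
        shift_logits = logits[..., :-1, :].contiguous()   # position t predicts token t+1
        shift_labels = labels[..., 1:].contiguous()
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)),
                         shift_labels.view(-1))
```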
```diff
@@ -227,6 +234,7 @@ class GPTLMLoss(nn.Module):
 @MODELS.register_module
 class GPT(nn.Module):
     def __init__(self,
                  vocab_size: int = 50304,
                  max_position_embeddings: int = 1024,
```
```diff
@@ -279,6 +287,18 @@ class GPT(nn.Module):
     def forward(self, input_ids, attention_mask=None):
         x, attention_mask = self.embed(input_ids, attention_mask)

+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # Adapted from huggingface
+        if attention_mask is not None:
+            batch_size = x.shape[0]
+            attention_mask = attention_mask.view(batch_size, -1)
+            attention_mask = col_nn.partition_batch(attention_mask)
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attention_mask = attention_mask.to(dtype=x.dtype)    # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
         for block in self.blocks:
             x, attention_mask = block(x, attention_mask)
```
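This is the core of the fix: `GPT.forward` now converts a 2D padding mask into the additive 4D bias the attention layers expect. Leaving out the tensor-parallel `col_nn.partition_batch` step, the transformation can be traced on a toy mask (the values below are made up for illustration):

```python
import torch

# 2D mask: 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
x_dtype = torch.float16  # stands in for x.dtype

batch_size = attention_mask.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]
attention_mask = attention_mask.to(dtype=x_dtype)          # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0

print(attention_mask.shape)     # torch.Size([2, 1, 1, 4])
print(attention_mask[0, 0, 0])  # padding positions become -10000.0, kept positions 0.0
```

Because the bias broadcasts over the head and query dimensions, adding it to the raw attention scores effectively removes padded key positions from the softmax.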
```diff
@@ -288,6 +308,7 @@ class GPT(nn.Module):
 class PipelineGPT(nn.Module):
     def __init__(self,
                  vocab_size: int = 50304,
                  max_position_embeddings: int = 1024,
```
```diff
@@ -355,7 +376,7 @@ class PipelineGPT(nn.Module):
             attention_mask = attention_mask.view(batch_size, -1)
             attention_mask = col_nn.partition_batch(attention_mask)
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-            attention_mask = attention_mask.to(dtype=x.dtype)    # fp16 compatibility
+            attention_mask = attention_mask.to(dtype=x.dtype)    # fp16 compatibility
             attention_mask = (1.0 - attention_mask) * -10000.0
         for block in self.blocks:
```