Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
28830402
Unverified
Commit
28830402
authored
Apr 26, 2022
by
LuGY
Committed by
GitHub
Apr 26, 2022
Browse files
[example] change qkv processing (#870)
parent
96211c2c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
6 deletions
+7
-6
model_zoo/gpt/gpt.py
model_zoo/gpt/gpt.py
+7
-6
No files found.
model_zoo/gpt/gpt.py
View file @
28830402
...
...
@@ -89,13 +89,14 @@ class GPTSelfAttention(nn.Module):
def
forward
(
self
,
x
,
attention_mask
=
None
):
qkv
=
self
.
query_key_value
(
x
)
all_head_size
=
qkv
.
shape
[
-
1
]
//
3
num_attention_heads
=
divide
(
all_head_size
,
self
.
attention_head_size
)
new_qkv_shape
=
qkv
.
shape
[:
-
1
]
+
\
(
num_attention_heads
,
3
*
self
.
attention_head_size
)
qkv
=
qkv
.
view
(
new_qkv_shape
)
qkv
=
qkv
.
permute
((
0
,
2
,
1
,
3
))
q
,
k
,
v
=
torch
.
chunk
(
qkv
,
3
,
dim
=-
1
)
all_head_size
=
q
.
shape
[
-
1
]
num_attention_heads
=
divide
(
all_head_size
,
self
.
attention_head_size
)
new_shape
=
q
.
shape
[:
-
1
]
+
\
(
num_attention_heads
,
self
.
attention_head_size
)
q
=
q
.
view
(
new_shape
).
permute
((
0
,
2
,
1
,
3
)).
contiguous
()
k
=
k
.
view
(
new_shape
).
permute
((
0
,
2
,
1
,
3
)).
contiguous
()
v
=
v
.
view
(
new_shape
).
permute
((
0
,
2
,
1
,
3
)).
contiguous
()
x
=
torch
.
matmul
(
q
,
k
.
transpose
(
-
1
,
-
2
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment