OpenDAS / ColossalAI · Commits

Commit eaac03ae, authored Mar 09, 2022 by ExtremeViscent; committed by Frank Lee, Mar 11, 2022

[format] format fixed for kernel/cuda_native codes (#335)

Parent: 00670c87
Showing 2 changed files with 15 additions and 17 deletions.
colossalai/kernel/cuda_native/layer_norm.py (+4, -4)
colossalai/kernel/cuda_native/multihead_attention.py (+11, -13)
colossalai/kernel/cuda_native/layer_norm.py (view file @ eaac03ae)
colossalai/kernel/cuda_native/multihead_attention.py (view file @ eaac03ae)
@@ -9,7 +9,7 @@ from torch.autograd import Function

 def check_config(config):
     if config.hidden_size % config.nhead != 0:
-        raise Exception(f"hidden_size % nhead != 0")
+        raise Exception("hidden_size % nhead != 0")

     factor = 8 if config.fp16 else 4
     upbound = factor * 1024 * 4
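Only the `f` prefix changes here: the string contains no placeholders, so it was never interpolated. For context, a minimal standalone sketch of the validation being touched, with an illustrative `SimpleNamespace` standing in for the module's real config object:

    from types import SimpleNamespace

    # Illustrative config; the real module passes its own config object
    # carrying hidden_size, nhead, and fp16 attributes.
    cfg = SimpleNamespace(hidden_size=1024, nhead=16, fp16=True)

    if cfg.hidden_size % cfg.nhead != 0:
        # No placeholders in the message, so the old f-prefix was inert.
        raise Exception("hidden_size % nhead != 0")

    factor = 8 if cfg.fp16 else 4    # larger bound when fp16 is enabled
    upbound = factor * 1024 * 4      # 32768 for fp16, 16384 for fp32
    print(upbound)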
@@ -215,15 +215,14 @@ class MultiHeadAttention(nn.Module):
         with torch.no_grad():
             self.in_proj_weight.copy_(
-                attn_qkvw_global.view(3, hs, hs)[:,
-                                                 int(hs * rank_in_pg /
-                                                     self.pg_size):
-                                                 int(hs * (rank_in_pg + 1) /
-                                                     self.pg_size),
-                                                 :])
+                attn_qkvw_global.view(3, hs, hs)[
+                    :, int(hs * rank_in_pg /
+                           self.pg_size):
+                    int(hs * (rank_in_pg + 1) /
+                        self.pg_size),
+                    :])
             self.in_proj_bias.copy_(
-                attn_qkvb_global.view(3, hs)[:,
-                                             int(hs * rank_in_pg /
-                                                 self.pg_size):
-                                             int(hs * (rank_in_pg + 1) /
-                                                 self.pg_size)])
+                attn_qkvb_global.view(3, hs)[
+                    :, int(hs * rank_in_pg /
+                           self.pg_size):
+                    int(hs * (rank_in_pg + 1) /
+                        self.pg_size)])

         attn_ow_global = torch.empty(hs, hs)
         nn.init.xavier_uniform_(attn_ow_global, 1.0)
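The reformatted slice is the tensor-parallel partitioning step: each rank copies its contiguous share of the globally initialized QKV weight, rows [hs * rank / pg_size, hs * (rank + 1) / pg_size) of each of Q, K, and V. A single-process sketch of that arithmetic; the names mirror the diff, while the loop over ranks is illustrative (in the real module each process computes only its own slice):

    import torch

    hs = 8          # hidden size (tiny, for illustration)
    pg_size = 4     # number of tensor-parallel ranks

    # Globally initialized QKV weight, as in the diff: a (3*hs, hs) tensor
    # viewed as (3, hs, hs) so Q, K, and V are sliced consistently.
    attn_qkvw_global = torch.randn(3 * hs, hs)

    shards = []
    for rank_in_pg in range(pg_size):
        start = int(hs * rank_in_pg / pg_size)
        stop = int(hs * (rank_in_pg + 1) / pg_size)
        # Each rank keeps rows [start, stop) of Q, K, and V.
        shards.append(attn_qkvw_global.view(3, hs, hs)[:, start:stop, :])

    # Concatenating the shards along dim 1 reconstructs the global weight,
    # confirming the slices tile it without overlap.
    rebuilt = torch.cat(shards, dim=1)
    assert torch.equal(rebuilt, attn_qkvw_global.view(3, hs, hs))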
@@ -231,10 +230,9 @@ class MultiHeadAttention(nn.Module):
             torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
             attn_ow_global = attn_ow_global.cpu()
             with torch.no_grad():
-                self.out_proj_weight.copy_(attn_ow_global[:,
-                                                          int(hs * rank_in_pg /
-                                                              self.pg_size):
-                                                          int(hs * (rank_in_pg + 1) /
-                                                              self.pg_size)])
+                self.out_proj_weight.copy_(attn_ow_global[
+                    :, int(hs * rank_in_pg /
+                           self.pg_size):
+                    int(hs * (rank_in_pg + 1) /
+                        self.pg_size)])
         else:
             attn_qkvw = self.in_proj_weight.view(-1, hs)
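This hunk shows the initialize-then-shard pattern for the output projection: rank 0's Xavier-initialized weight is broadcast over the process group, so every rank slices columns from an identical global tensor. A runnable single-process sketch of that sequence; the gloo backend and world_size=1 setup are illustrative stand-ins for a real multi-rank launch:

    import os
    import torch
    import torch.distributed as dist

    # Single-process group, just to exercise the broadcast call; real
    # tensor parallelism launches one process per rank.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    hs = 8
    attn_ow_global = torch.empty(hs, hs)
    torch.nn.init.xavier_uniform_(attn_ow_global, 1.0)

    # As in the diff: rank 0's initialization is broadcast so all ranks
    # agree on the global tensor before keeping their own columns.
    dist.broadcast(attn_ow_global, src=0)

    rank, pg_size = dist.get_rank(), dist.get_world_size()
    shard = attn_ow_global[:, int(hs * rank / pg_size):
                           int(hs * (rank + 1) / pg_size)]
    print(shard.shape)  # (hs, hs // pg_size)

    dist.destroy_process_group()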