Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c50c08dc
Commit
c50c08dc
authored
May 16, 2022
by
Kai Wang (Victor Kai)
Committed by
binmakeswell
May 17, 2022
Browse files
[NFC] polish colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu code style (#979)
parent
f28c0213
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
30 deletions
+16
-30
colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
...ssalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
+16
-30
No files found.
colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
View file @
c50c08dc
#include <cooperative_groups.h>
#include <chrono>
#include <ctime>
#include "kernels.h"
#include <cooperative_groups.h>
namespace
cg
=
cooperative_groups
;
curandStatePhilox4_32_10_t
*
curandstate
;
...
...
@@ -165,8 +165,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
const
float
scale
=
1.
f
/
(
1.
f
-
ratio
);
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
*
4
>=
total_count
)
return
;
if
(
i
*
4
>=
total_count
)
return
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
i
,
0
,
&
state
);
...
...
@@ -202,8 +201,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
*
8
>=
total_count
)
return
;
if
(
i
*
8
>=
total_count
)
return
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
i
,
0
,
&
state
);
...
...
@@ -261,8 +259,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
const
float
scale
=
1.
f
/
(
1.
f
-
ratio
);
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
*
4
>=
total_count
)
return
;
if
(
i
*
4
>=
total_count
)
return
;
uint8_t
m
[
4
];
...
...
@@ -289,8 +286,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
*
8
>=
total_count
)
return
;
if
(
i
*
8
>=
total_count
)
return
;
float4
*
out4
=
reinterpret_cast
<
float4
*>
(
out
);
const
float4
*
vals_float4
=
reinterpret_cast
<
const
float4
*>
(
in
);
...
...
@@ -380,8 +376,7 @@ __global__ void ls_dropout_res_bias_kernel(
const
float
scale
=
1.
f
/
(
1.
f
-
ratio
);
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
*
4
>=
total_count
)
return
;
if
(
i
*
4
>=
total_count
)
return
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
i
,
0
,
&
state
);
...
...
@@ -424,8 +419,7 @@ __global__ void ls_dropout_res_bias_kernel(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
*
8
>=
total_count
)
return
;
if
(
i
*
8
>=
total_count
)
return
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
i
,
0
,
&
state
);
...
...
@@ -565,11 +559,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
}
__syncthreads
();
for
(
int
i
=
1
;
i
<
32
;
i
<<=
1
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
for
(
int
i
=
1
;
i
<
32
;
i
<<=
1
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
if
(
y
==
0
)
tile
[
0
][
x
]
=
sum
;
if
(
y
==
0
)
tile
[
0
][
x
]
=
sum
;
__syncthreads
();
if
(
threadIdx
.
x
<
8
)
{
...
...
@@ -621,11 +613,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
}
__syncthreads
();
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
<<=
1
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
<<=
1
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
if
(
y
==
0
)
tile
[
0
][
x
]
=
sum
;
if
(
y
==
0
)
tile
[
0
][
x
]
=
sum
;
__syncthreads
();
if
(
threadIdx
.
x
<
8
)
{
...
...
@@ -689,8 +679,7 @@ __global__ void ls_dropout_act_bias_kernel(
const
float
scale
=
1.
f
/
(
1.
f
-
ratio
);
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
*
4
>=
total_count
)
return
;
if
(
i
*
4
>=
total_count
)
return
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
i
,
0
,
&
state
);
...
...
@@ -735,8 +724,7 @@ __global__ void ls_dropout_act_bias_kernel(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
*
8
>=
total_count
)
return
;
if
(
i
*
8
>=
total_count
)
return
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
i
,
0
,
&
state
);
...
...
@@ -897,11 +885,9 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
float
sum
=
tile
[
threadIdx
.
y
][
threadIdx
.
x
];
__syncthreads
();
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
<<=
1
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
<<=
1
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
if
(
threadIdx
.
x
==
0
)
tile
[
0
][
threadIdx
.
y
]
=
sum
;
if
(
threadIdx
.
x
==
0
)
tile
[
0
][
threadIdx
.
y
]
=
sum
;
__syncthreads
();
if
(
threadIdx
.
y
==
0
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment