Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
3d1987d1
Commit
3d1987d1
authored
Apr 04, 2021
by
TiagoMAntunes
Browse files
New compute kernel for column reduction
parent
f957c299
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
15 deletions
+26
-15
cuda/moe_compute_kernel.cu
cuda/moe_compute_kernel.cu
+26
-15
No files found.
cuda/moe_compute_kernel.cu
View file @
3d1987d1
...
@@ -46,29 +46,35 @@ __global__
...
@@ -46,29 +46,35 @@ __global__
void
column_reduce
(
const
scalar_t
*
matrix
,
scalar_t
*
result
,
void
column_reduce
(
const
scalar_t
*
matrix
,
scalar_t
*
result
,
int
m
/* lines */
,
int
n
/* columns*/
)
{
int
m
/* lines */
,
int
n
/* columns*/
)
{
extern
__shared__
float
sdata
[];
extern
__shared__
float
sdata
[];
unsigned
int
tid
=
threadIdx
.
x
;
// line
unsigned
int
tid
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
// line
unsigned
int
i
=
block
Idx
.
x
+
threadIdx
.
x
*
n
;
// get to idx th line
unsigned
int
i
=
thread
Idx
.
x
*
n
+
threadIdx
.
y
+
blockIdx
.
y
*
blockDim
.
y
;
// get to idx th line
unsigned
int
offset
=
0
;
unsigned
int
offset
=
0
;
unsigned
int
it
=
n
*
blockDim
.
x
;
// advanced blockDim.x threads vertically
unsigned
int
it
=
n
*
blockDim
.
x
;
// advance blockDim.x threads vertically
unsigned
int
real_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
// sum all the values from that column to fit in one single block
// sum all the values from that column to fit in one single block
sdata
[
tid
]
=
0
;
sdata
[
tid
]
=
0
;
while
(
i
+
offset
<
n
*
m
)
{
if
(
real_y
<
n
&&
threadIdx
.
x
<
m
)
// remember we only have one x block
sdata
[
tid
]
+=
matrix
[
i
+
offset
];
while
(
i
+
offset
<
n
*
m
)
{
offset
+=
it
;
sdata
[
tid
]
+=
matrix
[
i
+
offset
];
offset
+=
it
;
}
}
__syncthreads
();
__syncthreads
();
for
(
unsigned
int
s
=
1
;
tid
+
s
<
blockDim
.
x
;
s
*=
2
)
{
unsigned
int
lowest
=
blockDim
.
x
>
m
?
m
:
blockDim
.
x
;
if
(
tid
%
(
2
*
s
)
==
0
)
{
if
(
real_y
<
n
&&
threadIdx
.
x
<
m
)
sdata
[
tid
]
+=
sdata
[
tid
+
s
];
for
(
unsigned
int
s
=
1
;
threadIdx
.
x
+
s
<
lowest
;
s
*=
2
)
{
if
(
threadIdx
.
x
%
(
2
*
s
)
==
0
)
{
sdata
[
tid
]
+=
sdata
[
tid
+
s
];
}
__syncthreads
();
}
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
&&
real_y
<
n
)
{
result
[
real_y
]
=
sdata
[
tid
];
}
}
if
(
tid
==
0
)
{
result
[
blockIdx
.
x
]
=
sdata
[
0
];}
}
}
void
moe_cuda_expert_count_impl
(
void
moe_cuda_expert_count_impl
(
...
@@ -198,6 +204,11 @@ void moe_cuda_backward_impl(
...
@@ -198,6 +204,11 @@ void moe_cuda_backward_impl(
CudaStreamManager
*
smgr
)
{
CudaStreamManager
*
smgr
)
{
scalar_t
alpha
=
1
,
beta
=
0
;
scalar_t
alpha
=
1
,
beta
=
0
;
// bias
dim3
block_threads
(
32
,
32
);
dim3
grid_threads
(
1
,
out_feat
/
32
+
(
out_feat
%
32
?
1
:
0
));
for
(
int
i
=
0
,
ptr
=
0
;
i
<
num_expert
;
++
i
)
{
for
(
int
i
=
0
,
ptr
=
0
;
i
<
num_expert
;
++
i
)
{
if
(
expert_count
[
i
]
==
0
)
{
if
(
expert_count
[
i
]
==
0
)
{
cudaMemset
(
grad_weight
+
i
*
in_feat
*
out_feat
,
0
,
cudaMemset
(
grad_weight
+
i
*
in_feat
*
out_feat
,
0
,
...
@@ -235,7 +246,7 @@ void moe_cuda_backward_impl(
...
@@ -235,7 +246,7 @@ void moe_cuda_backward_impl(
if
(
has_bias
)
{
if
(
has_bias
)
{
column_reduce
column_reduce
<<<
out_feat
,
1024
,
sizeof
(
scalar_t
)
*
1024
,
smgr
->
stream
(
0
)
>>>
<<<
grid_threads
,
block_threads
,
sizeof
(
scalar_t
)
*
1024
,
smgr
->
stream
(
0
)
>>>
(
(
grad_output_buf
+
ptr
*
out_feat
,
grad_output_buf
+
ptr
*
out_feat
,
grad_bias
+
i
*
out_feat
,
grad_bias
+
i
*
out_feat
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment