OpenDAS / FastMoE · Commits · b0704f1d
Commit b0704f1d authored May 05, 2021 by TiagoMAntunes
New data placement structure, speedup
parent f165928a
Showing 1 changed file with 27 additions and 24 deletions
cuda/moe_compute_kernel.cu  +27  -24
@@ -47,45 +47,48 @@ void column_reduce(const scalar_t * matrix, scalar_t * result,
                    int m /* lines */, int n /* columns*/) {
     // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
-    extern __shared__ __align__(sizeof(scalar_t)) unsigned char my_smem[];
+    extern __shared__ unsigned char my_smem[];
     scalar_t *sdata = reinterpret_cast<scalar_t *>(my_smem);
 
-    unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x; // normal tid
-    unsigned int real_x = threadIdx.x + blockDim.x * blockIdx.x;
-    unsigned int real_y = n * threadIdx.y;
-    unsigned int i = real_x + real_y;
-    unsigned int it = n * blockDim.y;
-    unsigned int offset = it;
-
-    sdata[tid] = 0;
+    int tid = threadIdx.x + threadIdx.y * blockDim.x;
+    // transposed tid for shared memory
+    int new_tid = threadIdx.y + threadIdx.x * blockDim.y;
+    // true x value in the matrix
+    int real_x = threadIdx.x + blockDim.x * blockIdx.x;
+    int i = real_x + n * threadIdx.y;
+    const int it = n * blockDim.y;
+    int offset = it;
+    float accumulator = 0;
 
     if (threadIdx.y < m && real_x < n) {
-        // can load memory
-        // printf("tid=%d loading %d\n", tid, i);
-        sdata[tid] = matrix[i];
+        // store all the values from this column in a warped way
+        accumulator = matrix[i];
         while (i + offset < n * m) {
-            // printf("tid=%d loading %d\n", tid, i+offset);
-            sdata[tid] += matrix[i + offset];
+            accumulator += matrix[i + offset];
             offset += it;
         }
     }
 
+    // save column reduction data in a transposed way
+    sdata[new_tid] = accumulator;
     __syncthreads();
 
-    for (unsigned int s = blockDim.y / 2; s > 0; s >>= 1) {
-        if (threadIdx.y < s) {
-            // printf("tid=%d adding %d\n", tid, tid + blockDim.x *s);
-            sdata[tid] += sdata[tid + blockDim.x * s];
-        }
+    for (size_t t = 16; t > 0; t >>= 1) {
+        if (tid < 32 * 32 - 16)
+            sdata[tid] += sdata[tid + t];
         __syncthreads();
     }
 
-    if (threadIdx.y == 0 && real_x < n) {
-        result[real_x] = sdata[tid];
-    }
+    if (threadIdx.y == 0 && real_x < n)
+        result[real_x] = sdata[new_tid];
 }
 
 void moe_cuda_expert_count_impl(
     const int * d_gate,
     int * expert_count,
     ...
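The idea behind the change: each thread now keeps its running column sum in a register and writes it to shared memory only once, in transposed order, so the 32 partial sums belonging to one column occupy consecutive shared-memory slots and can be collapsed with small strides. Below is a minimal, self-contained sketch of that placement pattern, assuming a 32x32 thread block and a row-major float matrix. The kernel name col_reduce_sketch, the BLOCK_X/BLOCK_Y constants, the static shared-memory array (the commit uses dynamic extern shared memory), and the race-free reduction guard are illustrative choices, not code from this repository.

#include <cuda_runtime.h>

constexpr int BLOCK_X = 32;   // columns handled per block (assumed)
constexpr int BLOCK_Y = 32;   // threads cooperating on each column (assumed)

__global__ void col_reduce_sketch(const float *matrix, float *result,
                                  int m /* rows */, int n /* columns */) {
    // static shared memory for simplicity; one slot per thread
    __shared__ float sdata[BLOCK_X * BLOCK_Y];

    int tid     = threadIdx.x + threadIdx.y * blockDim.x;  // row-major thread id
    int new_tid = threadIdx.y + threadIdx.x * blockDim.y;  // transposed id
    int col     = threadIdx.x + blockDim.x * blockIdx.x;   // column in the matrix

    // Each thread accumulates every blockDim.y-th row of its column in a register,
    // so the loads issued by a warp stay coalesced along a row.
    float acc = 0.0f;
    if (col < n) {
        for (int row = threadIdx.y; row < m; row += blockDim.y) {
            acc += matrix[row * n + col];
        }
    }

    // Transposed placement: the 32 partial sums of one column land in 32
    // consecutive shared-memory slots.
    sdata[new_tid] = acc;
    __syncthreads();

    // Collapse each 32-slot group with small strides; guarding on threadIdx.x < t
    // keeps the read and write sets disjoint (stricter than the commit's guard).
    for (int t = BLOCK_Y / 2; t > 0; t >>= 1) {
        if (threadIdx.x < t) {
            sdata[tid] += sdata[tid + t];
        }
        __syncthreads();
    }

    // Slot 0 of each group now holds the full column sum.
    if (threadIdx.y == 0 && col < n) {
        result[col] = sdata[new_tid];
    }
}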
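A hypothetical host-side launch for the sketch above (same assumptions, compiled in the same .cu file; sizes and variable names are made up) shows the intended shape: one 32x32 block per 32 columns, with the column sums written to a length-n result vector.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

int main() {
    const int m = 1000, n = 70;                  // illustrative matrix size
    std::vector<float> host(m * n, 1.0f);        // every column sums to m

    float *d_matrix = nullptr, *d_result = nullptr;
    cudaMalloc((void **)&d_matrix, m * n * sizeof(float));
    cudaMalloc((void **)&d_result, n * sizeof(float));
    cudaMemcpy(d_matrix, host.data(), m * n * sizeof(float), cudaMemcpyHostToDevice);

    dim3 block(32, 32);                          // BLOCK_X x BLOCK_Y threads
    dim3 grid((n + 31) / 32);                    // one block per 32 columns
    col_reduce_sketch<<<grid, block>>>(d_matrix, d_result, m, n);

    std::vector<float> sums(n);
    cudaMemcpy(sums.data(), d_result, n * sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("column 0 sum = %.1f (expected %d)\n", sums[0], m);

    cudaFree(d_matrix);
    cudaFree(d_result);
    return 0;
}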