Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
d690c7b2
Commit
d690c7b2
authored
Dec 30, 2020
by
Rick Ho
Browse files
manual batched scatter and gather kernels
parent
254ad118
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
52 additions
and
26 deletions
+52
-26
pytorch/cuda/cuda_stream_manager.cpp
pytorch/cuda/cuda_stream_manager.cpp
+5
-1
pytorch/cuda/cuda_stream_manager.h
pytorch/cuda/cuda_stream_manager.h
+1
-1
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+46
-24
No files found.
pytorch/cuda/cuda_stream_manager.cpp
View file @
d690c7b2
...
@@ -11,7 +11,11 @@ CudaStreamManager* getCudaStreamManager(const size_t num_expert) {
...
@@ -11,7 +11,11 @@ CudaStreamManager* getCudaStreamManager(const size_t num_expert) {
return
smgr
;
return
smgr
;
}
}
// Block the host until work queued on this manager's stream(s) completes.
//
// i >= 0            : wait only for streams[i] to drain.
// i == -1 (default) : wait for all MAX_STREAMS streams in turn.
//
// NOTE(review): return codes of cudaStreamSynchronize are discarded here —
// confirm callers perform their own error checking.
void CudaStreamManager::sync(int i) {
    if (i > -1) {
        cudaStreamSynchronize(streams[i]);
        return;
    }
    for (size_t s = 0; s < MAX_STREAMS; ++s) {
        cudaStreamSynchronize(streams[s]);
    }
}
...
...
pytorch/cuda/cuda_stream_manager.h
View file @
d690c7b2
...
@@ -38,7 +38,7 @@ struct CudaStreamManager {
...
@@ -38,7 +38,7 @@ struct CudaStreamManager {
return
handles
[
idx
%
MAX_STREAMS
];
return
handles
[
idx
%
MAX_STREAMS
];
}
}
void
sync
();
void
sync
(
int
=-
1
);
};
};
CudaStreamManager
*
getCudaStreamManager
(
const
size_t
num_expert
);
CudaStreamManager
*
getCudaStreamManager
(
const
size_t
num_expert
);
...
...
pytorch/cuda/moe_cuda_kernel.cu
View file @
d690c7b2
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
// #define MOE_BREAKDOWN
// #define MOE_BREAKDOWN
#define MOE_DEBUG
//
#define MOE_DEBUG
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
__global__
...
@@ -31,6 +31,29 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
...
@@ -31,6 +31,29 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
}
}
// Scatter step of MoE dispatch: copy each input row to its permuted slot.
//
// Launch layout: gridDim.x == number of rows (one block per row), any 1-D
// block size; threads of a block stride across the `wid` elements of the
// row, so adjacent threads touch adjacent elements (coalesced).
//
//   wid   : row width in elements
//   pos   : device array; pos[b] is the destination row for source row b
//           (assumed in-range for oubuf — TODO confirm at call sites)
//   inbuf : source buffer, rows contiguous in source order
//   oubuf : destination buffer; row b of inbuf lands at row pos[b]
template <typename scalar_t>
__global__
void batch_scatter_kernel(int wid, const int* pos,
                          const scalar_t* inbuf, scalar_t* oubuf) {
    // Compute row offsets in size_t: the products wid * blockIdx.x and
    // wid * pos[blockIdx.x] overflow 32-bit int arithmetic once a buffer
    // exceeds ~2^31 elements.
    inbuf += (size_t)wid * blockIdx.x;
    oubuf += (size_t)wid * pos[blockIdx.x];
    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}
// Gather step of MoE combine: inverse of batch_scatter_kernel — copy each
// permuted row back to its original (dense) position.
//
// Launch layout: gridDim.x == number of rows (one block per row), any 1-D
// block size; threads of a block stride across the `wid` elements of the
// row, so adjacent threads touch adjacent elements (coalesced).
//
//   wid   : row width in elements
//   pos   : device array; pos[b] is the source row (in inbuf) for
//           destination row b (assumed in-range — TODO confirm at call sites)
//   inbuf : permuted source buffer
//   oubuf : dense destination buffer, rows contiguous in output order
template <typename scalar_t>
__global__
void batch_gather_kernel(int wid, const int* pos,
                         const scalar_t* inbuf, scalar_t* oubuf) {
    // size_t offsets: wid * pos[blockIdx.x] / wid * blockIdx.x overflow
    // 32-bit int arithmetic for buffers larger than ~2^31 elements.
    inbuf += (size_t)wid * pos[blockIdx.x];
    oubuf += (size_t)wid * blockIdx.x;
    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
moe_cuda_forward_impl
(
void
moe_cuda_forward_impl
(
const
scalar_t
*
input
,
const
scalar_t
*
input
,
...
@@ -86,22 +109,26 @@ void moe_cuda_forward_impl(
...
@@ -86,22 +109,26 @@ void moe_cuda_forward_impl(
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
]
+
expert_count
[
i
-
1
];
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
]
+
expert_count
[
i
-
1
];
}
}
int
*
pos
=
new
int
[
batch_size
];
int
*
d_pos
;
checkCudaErrors
(
cudaMalloc
(
&
d_pos
,
sizeof
(
int
)
*
batch_size
));
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
pos
[
i
]
=
expert_ptr
[
gate
[
i
]]
++
;
}
checkCudaErrors
(
cudaMemcpy
(
d_pos
,
pos
,
sizeof
(
int
)
*
batch_size
,
cudaMemcpyHostToDevice
));
#ifdef MOE_BREAKDOWN
#ifdef MOE_BREAKDOWN
timestamp
(
t_expert
);
timestamp
(
t_expert
);
fprintf
(
stderr
,
"Expert asn time %.3lf us
\n
"
,
getDuration
(
t_cpy
,
t_expert
)
*
fprintf
(
stderr
,
"Expert asn time %.3lf us
\n
"
,
getDuration
(
t_cpy
,
t_expert
)
*
1e6
);
1e6
);
#endif
#endif
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
batch_scatter_kernel
<
scalar_t
>
int
target_idx
=
expert_ptr
[
gate
[
i
]]
++
;
<<<
batch_size
,
256
,
0
,
h
->
getStream
(
0
)
>>>
(
in_feat
,
d_pos
,
input
,
#ifdef MOE_DEBUG_SCATTER
input_buf
);
fprintf
(
stderr
,
"aln idx %d gate %d tgt %d
\n
"
,
i
,
gate
[
i
],
target_idx
);
h
->
sync
(
0
);
#endif
checkCudaErrors
(
cudaMemcpyAsync
(
input_buf
+
target_idx
*
in_feat
,
input
+
i
*
in_feat
,
sizeof
(
scalar_t
)
*
in_feat
,
cudaMemcpyDeviceToDevice
,
h
->
getStream
(
gate
[
i
])));
}
#ifdef MOE_BREAKDOWN
#ifdef MOE_BREAKDOWN
h
->
sync
();
h
->
sync
();
...
@@ -148,25 +175,16 @@ void moe_cuda_forward_impl(
...
@@ -148,25 +175,16 @@ void moe_cuda_forward_impl(
}
}
#ifdef MOE_BREAKDOWN
#ifdef MOE_BREAKDOWN
h
->
sync
();
timestamp
(
t_mm
);
timestamp
(
t_mm
);
fprintf
(
stderr
,
"GeMM time %.3lf us
\n
"
,
getDuration
(
t_scatter
,
t_mm
)
*
fprintf
(
stderr
,
"GeMM time %.3lf us
\n
"
,
getDuration
(
t_scatter
,
t_mm
)
*
1e6
);
1e6
);
#endif
#endif
for
(
int
i
=
batch_size
-
1
;
i
>=
0
;
--
i
)
{
int
target_idx
=
--
expert_ptr
[
gate
[
i
]];
#ifdef MOE_DEBUG_SCATTER
fprintf
(
stderr
,
"cb idx %d gate %d tgt %d
\n
"
,
i
,
gate
[
i
],
target_idx
);
#endif
checkCudaErrors
(
cudaMemcpyAsync
(
output
+
i
*
out_feat
,
output_buf
+
target_idx
*
out_feat
,
sizeof
(
scalar_t
)
*
out_feat
,
cudaMemcpyDeviceToDevice
,
h
->
getStream
(
gate
[
i
])));
}
h
->
sync
();
h
->
sync
();
batch_gather_kernel
<
scalar_t
>
<<<
batch_size
,
256
,
0
,
h
->
getStream
(
0
)
>>>
(
out_feat
,
d_pos
,
output_buf
,
output
);
h
->
sync
(
0
);
#ifdef MOE_BREAKDOWN
#ifdef MOE_BREAKDOWN
timestamp
(
t_gather
);
timestamp
(
t_gather
);
...
@@ -177,7 +195,11 @@ void moe_cuda_forward_impl(
...
@@ -177,7 +195,11 @@ void moe_cuda_forward_impl(
#endif
#endif
cudaFree
(
input_buf
);
cudaFree
(
input_buf
);
cudaFree
(
hidden_buf
);
cudaFree
(
output_buf
);
cudaFree
(
output_buf
);
cudaFree
(
d_pos
);
delete
[]
pos
;
delete
[]
gate
;
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment