Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
25e9324f
Commit
25e9324f
authored
Dec 28, 2022
by
Rick Ho
Browse files
cudaStreamWaitEvent compat
parent
c7e6a3db
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
8 deletions
+13
-8
cuda/fastermoe/smart_schedule.cpp
cuda/fastermoe/smart_schedule.cpp
+1
-1
cuda/fastermoe/smart_schedule.h
cuda/fastermoe/smart_schedule.h
+12
-7
No files found.
cuda/fastermoe/smart_schedule.cpp
View file @
25e9324f
...
...
@@ -48,7 +48,7 @@ void _reduce_grad(
cudaEvent_t
evt_stash
;
cudaEventCreate
(
&
evt_stash
);
cudaEventRecord
(
evt_stash
,
torch_stream
);
cudaStreamWaitEvent
(
smgr
->
stream
(
0
),
evt_stash
,
0
);
FMOE_SWE
(
smgr
->
stream
(
0
),
evt_stash
);
cudaEventDestroy
(
evt_stash
);
auto
dtype
=
getNcclDataType
(
t
.
scalar_type
());
...
...
cuda/fastermoe/smart_schedule.h
View file @
25e9324f
...
...
@@ -11,6 +11,11 @@
#include "../stream_manager.h"
#if defined(CUDA_VERSION) && (CUDA_VERSION < 110010)
#define FMOE_SWE(__s__,__e__) cudaStreamWaitEvent(__s__,__e__,0)
#else
#define FMOE_SWE(__s__,__e__) cudaStreamWaitEvent(__s__,__e__)
#endif
template
<
typename
scalar_t
>
void
exchangeWith
(
...
...
@@ -169,7 +174,7 @@ void fmoe_cuda_fused_forward_impl(
if
(
i
/
num_expert
==
rank
)
{
cudaEventCreate
(
&
evt_get
);
cudaEventRecord
(
evt_get
,
torch_stream
);
cudaStreamWaitEvent
(
smgr
->
stream
(
1
),
evt_get
);
FMOE_SWE
(
smgr
->
stream
(
1
),
evt_get
);
cudaEventDestroy
(
evt_get
);
}
NCCL_SAFE_CALL
(
ncclBcast
((
void
*
)
params
[
si
].
data_ptr
<
scalar_t
>
(),
...
...
@@ -183,7 +188,7 @@ void fmoe_cuda_fused_forward_impl(
// C_0 ... C_n
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
cudaStreamWaitEvent
(
torch_stream
,
input_ready
[
step
]
,
0
);
FMOE_SWE
(
torch_stream
,
input_ready
[
step
]);
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
GEN_BASE
(
step
);
long
offset
=
global_ptr
[
ei
*
world_size
+
from_base
];
...
...
@@ -200,7 +205,7 @@ void fmoe_cuda_fused_forward_impl(
for
(
long
i
=
0
,
si
=
0
;
i
<
world_size
*
num_expert
;
++
i
)
{
if
(
stored_models
[
i
])
{
stash_fn
(
params
[
si
],
si
);
cudaStreamWaitEvent
(
torch_stream
,
evt_shadow
[
si
]
,
0
);
FMOE_SWE
(
torch_stream
,
evt_shadow
[
si
]);
long
offset
=
local_ptr
[
i
];
long
micro_batch_size
=
local_expert_count
[
i
];
computeFn
(
forward_fn
,
device
,
...
...
@@ -213,7 +218,7 @@ void fmoe_cuda_fused_forward_impl(
// R_0 ... R_n
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
cudaStreamWaitEvent
(
smgr
->
stream
(
0
),
output_ready
[
step
]
,
0
);
FMOE_SWE
(
smgr
->
stream
(
0
),
output_ready
[
step
]);
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
GEN_BASE
(
step
);
NCCL_SAFE_CALL
(
ncclGroupStart
());
...
...
@@ -331,7 +336,7 @@ void fmoe_cuda_fused_backward_impl(
// C_0 ... C_n
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
cudaStreamWaitEvent
(
smgr
->
stream
(
1
),
input_ready
[
step
]
,
0
);
FMOE_SWE
(
smgr
->
stream
(
1
),
input_ready
[
step
]);
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
GEN_BASE
(
step
);
long
offset
=
global_ptr
[
ei
*
world_size
+
from_base
];
...
...
@@ -349,7 +354,7 @@ void fmoe_cuda_fused_backward_impl(
for
(
long
i
=
0
,
si
=
0
;
i
<
world_size
*
num_expert
;
++
i
)
{
if
(
stored_models
[
i
])
{
if
(
i
/
num_expert
==
rank
)
{
cudaStreamWaitEvent
(
torch_stream
,
evt_reduce
[
i
%
num_expert
]
,
0
);
FMOE_SWE
(
torch_stream
,
evt_reduce
[
i
%
num_expert
]);
set_grad_fn
(
si
);
}
++
si
;
...
...
@@ -358,7 +363,7 @@ void fmoe_cuda_fused_backward_impl(
// R_0 ... R_n
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
cudaStreamWaitEvent
(
smgr
->
stream
(
0
),
output_ready
[
step
]
,
0
);
FMOE_SWE
(
smgr
->
stream
(
0
),
output_ready
[
step
]);
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
GEN_BASE
(
step
);
NCCL_SAFE_CALL
(
ncclGroupStart
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment