Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
ad651f03
Commit
ad651f03
authored
Mar 30, 2022
by
Rick Ho
Browse files
fix potential stream synchronization issue
parent
6c68b56b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
3 deletions
+25
-3
cuda/fastermoe/smart_schedule.cpp
cuda/fastermoe/smart_schedule.cpp
+11
-0
cuda/fastermoe/smart_schedule.h
cuda/fastermoe/smart_schedule.h
+6
-3
cuda/fastermoe/status.h
cuda/fastermoe/status.h
+8
-0
No files found.
cuda/fastermoe/smart_schedule.cpp
View file @
ad651f03
...
...
@@ -6,9 +6,19 @@
#include <c10/cuda/CUDAGuard.h>
#include "smart_schedule.h"
#include "status.h"
// Pipeline granularity for the smart schedule.
// NOTE(review): -1 appears to mean "not yet configured" (a later branch
// assigns 4 when no override is given) — confirm against the callers.
long pipeline_gran = -1;

// Non-zero once the smart (overlapped) schedule has been switched on.
int smart_sch_enabled = 0;

// Report the current smart-schedule flag (non-zero = enabled).
int isSmartSchEnabled() {
    return smart_sch_enabled;
}

// Set the smart-schedule flag: non-zero enables it, zero disables it.
void setSmartSchEnabled(int enabled) {
    smart_sch_enabled = enabled;
}
std
::
vector
<
torch
::
Tensor
>
_smart_sch_forward
(
torch
::
Tensor
input_buf
,
torch
::
Tensor
local_expert_count
,
...
...
@@ -24,6 +34,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
}
else
{
pipeline_gran
=
4
;
}
setSmartSchEnabled
(
1
);
}
auto
smgr
=
getCudaStreamManager
(
input_buf
.
device
().
index
());
...
...
cuda/fastermoe/smart_schedule.h
View file @
ad651f03
...
...
@@ -76,7 +76,8 @@ void _compute_ptrs(long num_expert, long rank, long world_size,
template
<
typename
scalar_t
>
void
_compute_fn
(
py
::
function
fn
,
c10
::
Device
device
,
scalar_t
*
inp_buf
,
scalar_t
*
out_buf
,
int
ei
,
long
step
,
long
offset
,
long
micro_batch_size
,
long
d_model
)
{
int
ei
,
long
step
,
long
offset
,
long
micro_batch_size
,
long
d_model
,
CudaStreamManager
*
smgr
)
{
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
c10
::
CppTypeToScalarType
<
scalar_t
>::
value
)
.
device
(
device
)
...
...
@@ -85,7 +86,9 @@ void _compute_fn(py::function fn, c10::Device device,
{
micro_batch_size
,
d_model
},
options
);
auto
oup
=
torch
::
from_blob
(
out_buf
+
offset
*
d_model
,
{
micro_batch_size
,
d_model
},
options
);
smgr
->
use_default
=
true
;
fn
(
inp
,
oup
,
step
);
smgr
->
use_default
=
false
;
}
...
...
@@ -156,7 +159,7 @@ void fmoe_cuda_fused_forward_impl(
_compute_fn
(
forward_fn
,
device
,
global_input_buf
,
global_output_buf
,
ei
,
step
,
offset
,
micro_batch_size
,
d_model
);
ei
,
step
,
offset
,
micro_batch_size
,
d_model
,
smgr
);
}
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
cudaEventRecord
(
output_ready
[
step
],
stream
);
...
...
@@ -286,7 +289,7 @@ void fmoe_cuda_fused_backward_impl(
_compute_fn
(
backward_fn
,
device
,
global_grad_out
,
global_grad_in
,
ei
,
step
,
offset
,
micro_batch_size
,
d_model
);
ei
,
step
,
offset
,
micro_batch_size
,
d_model
,
smgr
);
}
// TODO: get pytorch's compute stream
}
...
...
cuda/fastermoe/status.h
0 → 100644
View file @
ad651f03
#pragma once
#ifndef FASTER_STATUS_H
#define FASTER_STATUS_H
int
isSmartSchEnabled
();
void
setSmartSchEnabled
(
int
);
#endif // FASTER_STATUS_H
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment