OpenDAS / FastMoE

Commit 794dd0e6
Authored Mar 31, 2022 by Rick Ho
Parent: b5b72d41

    expert shadow backward with test
Showing 6 changed files with 148 additions and 84 deletions (+148 / -84):
cuda/fastermoe/smart_schedule.cpp  (+52, -5)
cuda/fastermoe/smart_schedule.h    (+67, -57)
cuda/fmoe_cuda.cpp                 (+5, -1)
fmoe/fastermoe/expert_utils.py     (+2, -3)
fmoe/fastermoe/schedule.py         (+14, -12)
tests/test_faster_shadow.py        (+8, -6)
cuda/fastermoe/smart_schedule.cpp

```diff
@@ -19,6 +19,51 @@ void setSmartSchEnabled(int s) {
     smart_sch_enabled = s;
 }
 
+inline ncclDataType_t getNcclDataType(at::ScalarType t) {
+    switch (t) {
+        case at::kChar: return ncclInt8;
+        case at::kByte: return ncclUint8;
+        case at::kFloat: return ncclFloat;
+        case at::kDouble: return ncclDouble;
+        case at::kInt: return ncclInt32;
+        case at::kLong: return ncclInt64;
+        case at::kHalf: return ncclHalf;
+        case at::kBool: return ncclUint8;
+#if defined(ENABLE_NCCL_BF16_DATATYPE)
+        case at::kBFloat16: return ncclBfloat16;
+#endif
+        default: return ncclChar;
+    }
+}
+
+void _reduce_grad(torch::Tensor t, long root, long expert_size) {
+    auto smgr = getCudaStreamManager(t.device().index());
+    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
+
+    cudaEvent_t evt_stash;
+    cudaEventCreate(&evt_stash);
+    cudaEventRecord(evt_stash, torch_stream);
+    cudaStreamWaitEvent(smgr->stream(0), evt_stash, 0);
+    cudaEventDestroy(evt_stash);
+
+    auto dtype = getNcclDataType(t.scalar_type());
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(),
+            "fmoe_cuda_reduce_grad", ([&] {
+        void* buf = (void*)t.data_ptr<scalar_t>();
+        NCCL_SAFE_CALL(ncclReduce(buf, buf, expert_size, dtype, ncclSum, root,
+                smgr->ncclcomm, smgr->stream(0)));
+    }));
+}
+
 std::vector<torch::Tensor> _smart_sch_forward(
         torch::Tensor input_buf,
         torch::Tensor local_expert_count,
@@ -51,7 +96,6 @@ std::vector<torch::Tensor> _smart_sch_forward(
     // TODO: maybe empty is faster
     auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
     auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
     auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});
     std::vector<torch::Tensor> params;
@@ -96,7 +140,6 @@ torch::Tensor _smart_sch_backward(
         torch::Tensor stored_models,
         long buf_batch_size,
         long global_batch_size,
+        long expert_size,
         long n_workers,
         py::function backward_fn,
         py::function stash_fn,
@@ -116,6 +159,10 @@ torch::Tensor _smart_sch_backward(
             "fmoe_cuda_smartsch_backward", ([&] {
         fmoe_cuda_fused_backward_impl(
             backward_fn,
+            stash_fn,
+            pop_fn,
+            collect_fn,
+            set_grad_fn,
             grad_out.device(),
             grad_out.data_ptr<scalar_t>(),
@@ -129,7 +176,7 @@ torch::Tensor _smart_sch_backward(
             d_model, num_expert, rank, n_workers,
             pipeline_gran, smgr);
     }));
-    return {grad_in,};
+    return grad_in;
 }
 
 #endif
```
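The new `_reduce_grad` hands the gradient tensor from PyTorch's current stream to FastMoE's communication stream by recording an event on the torch stream and making the comm stream wait on it before the NCCL reduce. Below is a minimal PyTorch-level sketch of that record-then-wait pattern, not the extension code itself; `side_stream` stands in for `smgr->stream(0)` and the `add_` call stands in for the `ncclReduce`, and a CUDA device is assumed.

```python
import torch

# Assumes a CUDA device is available; all names here are illustrative stand-ins.
t = torch.randn(1024, device="cuda")
side_stream = torch.cuda.Stream()          # stand-in for FastMoE's comm stream

evt = torch.cuda.Event()
evt.record(torch.cuda.current_stream())    # like cudaEventRecord(evt_stash, torch_stream)
side_stream.wait_event(evt)                # like cudaStreamWaitEvent(smgr->stream(0), evt_stash, 0)

with torch.cuda.stream(side_stream):
    t.add_(1.0)                            # stand-in for the reduction done on the comm stream

# Let the default stream see the side stream's result before reusing `t`.
torch.cuda.current_stream().wait_stream(side_stream)
```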
cuda/fastermoe/smart_schedule.h

```diff
@@ -13,7 +13,7 @@
 template<typename scalar_t>
-void _exchange_with(
+void exchangeWith(
         const scalar_t* sendbuf, size_t sendcount, int t_send,
         scalar_t* recvbuf, size_t recvcount, int t_recv,
         long d_model,
@@ -40,7 +40,7 @@ void _exchange_with(
     int idx_self = ei + rank * num_expert;
 
-void _compute_ptrs(long num_expert, long rank, long world_size,
+void computePtrs(long num_expert, long rank, long world_size,
         const long* local_expert_count,
         const long* global_expert_count,
         const bool* stored_models,
@@ -76,7 +76,7 @@ void _compute_ptrs(long num_expert, long rank, long world_size,
 template<typename scalar_t>
-void _compute_fn(py::function fn, c10::Device device,
+void computeFn(py::function fn, c10::Device device,
         scalar_t* inp_buf, scalar_t* out_buf,
         long idx, long offset, long micro_batch_size, long d_model,
         CudaStreamManager* smgr) {
@@ -119,7 +119,7 @@ void fmoe_cuda_fused_forward_impl(
     int* local_ptr = new int[num_expert * world_size + 1];
     int* global_ptr = new int[num_expert * world_size + 1];
     int* local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
-    _compute_ptrs(num_expert, rank, world_size,
+    computePtrs(num_expert, rank, world_size,
             local_expert_count, global_expert_count, stored_models,
             local_ptr, global_ptr, local_global_ptr);
@@ -145,7 +145,7 @@ void fmoe_cuda_fused_forward_impl(
             int rank_send = j + to_base;
             int rank_recv = j + from_base;
             GEN_IDX;
-            _exchange_with(input_buf + local_ptr[idx_send] * d_model,
+            exchangeWith(input_buf + local_ptr[idx_send] * d_model,
                     local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                     global_input_buf + global_ptr[gidx_recv] * d_model,
                     global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
@@ -167,6 +167,7 @@ void fmoe_cuda_fused_forward_impl(
             cudaEventCreate(&evt_get);
             cudaEventRecord(evt_get, torch_stream);
             cudaStreamWaitEvent(smgr->stream(1), evt_get);
+            cudaEventDestroy(evt_get);
         }
         NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
                 expert_size * sizeof(scalar_t), ncclChar,
@@ -185,8 +186,7 @@ void fmoe_cuda_fused_forward_impl(
             long offset = global_ptr[ei * world_size + from_base];
             long micro_batch_size = global_ptr[ei * world_size +
                 (from_base + pipeline_gran)] - offset;
-            _compute_fn(forward_fn, device,
+            computeFn(forward_fn, device,
                     global_input_buf, global_output_buf,
                     step, offset, micro_batch_size, d_model, smgr);
         }
@@ -200,7 +200,7 @@ void fmoe_cuda_fused_forward_impl(
         cudaStreamWaitEvent(torch_stream, evt_shadow[si], 0);
         long offset = local_ptr[i];
         long micro_batch_size = local_expert_count[i];
-        _compute_fn(forward_fn, device,
+        computeFn(forward_fn, device,
                 input_buf, output_buf,
                 n_groups + si, offset, micro_batch_size, d_model, smgr);
         ++si;
@@ -218,7 +218,7 @@ void fmoe_cuda_fused_forward_impl(
             int rank_send = j + from_base;
             int rank_recv = j + to_base;
             GEN_IDX;
-            _exchange_with(global_output_buf + global_ptr[gidx_send] * d_model,
+            exchangeWith(global_output_buf + global_ptr[gidx_send] * d_model,
                     global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                     output_buf + local_ptr[idx_recv] * d_model,
                     local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
@@ -241,15 +241,16 @@ void fmoe_cuda_fused_forward_impl(
     }
     delete [] input_ready;
     delete [] output_ready;
+    if (params.size()) {
+        delete [] evt_shadow;
+    }
 }
 
 template<typename scalar_t>
 void fmoe_cuda_fused_backward_impl(
         py::function backward_fn,
+        py::function stash_fn,
+        py::function pop_fn,
+        py::function collect_fn,
+        py::function set_grad_fn,
         c10::Device device,
         scalar_t* grad_out,
@@ -269,10 +270,9 @@ void fmoe_cuda_fused_backward_impl(
     int* global_ptr = new int[num_expert * world_size + 1];
     int* local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
-    _compute_ptrs(num_expert, rank, world_size,
+    computePtrs(num_expert, rank, world_size,
             local_expert_count, global_expert_count, stored_models,
             local_ptr, global_ptr, local_global_ptr);
     if (pipeline_gran > world_size) {
         pipeline_gran = world_size;
     }
@@ -286,6 +286,7 @@ void fmoe_cuda_fused_backward_impl(
         cudaEventCreate(output_ready + i);
     }
+    // S_0 ... S_n
     for (long step = 0; step < n_groups; ++step) {
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
@@ -294,7 +295,7 @@ void fmoe_cuda_fused_backward_impl(
             int rank_send = j + to_base;
             int rank_recv = j + from_base;
             GEN_IDX;
-            _exchange_with(grad_out + local_ptr[idx_send] * d_model,
+            exchangeWith(grad_out + local_ptr[idx_send] * d_model,
                     local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                     global_grad_out + global_ptr[gidx_recv] * d_model,
                     global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
@@ -305,6 +306,27 @@ void fmoe_cuda_fused_backward_impl(
         cudaEventRecord(input_ready[step], smgr->stream(0));
     }
+    // Shadowed experts backward and reduce
+    cudaEvent_t* evt_reduce = new cudaEvent_t[num_expert];
+    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
+        if (stored_models[i]) {
+            stash_fn(si);
+            long offset = local_ptr[i];
+            long micro_batch_size = local_expert_count[i];
+            computeFn(backward_fn, device,
+                    grad_out, grad_in,
+                    n_groups + si, offset, micro_batch_size, d_model, smgr);
+            collect_fn(si, i / num_expert);
+            if (i / num_expert == rank) {
+                cudaEventCreate(evt_reduce + i % num_expert);
+                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
+            }
+            ++si;
+        }
+    }
+    pop_fn();
+
+    // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
         cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
         for (int ei = 0; ei < num_expert; ++ei) {
@@ -313,13 +335,25 @@ void fmoe_cuda_fused_backward_impl(
             long micro_batch_size = global_ptr[ei * world_size +
                 (from_base + pipeline_gran)] - offset;
-            _compute_fn(backward_fn, device,
+            computeFn(backward_fn, device,
                     global_grad_out, global_grad_in,
                     step, offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], torch_stream);
     }
+    // Collect gradients for shadowed experts
+    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
+        if (stored_models[i]) {
+            if (i / num_expert == rank) {
+                cudaStreamWaitEvent(torch_stream, evt_reduce[i % num_expert], 0);
+                set_grad_fn(si);
+            }
+            ++si;
+        }
+    }
+
+    // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
         cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
         for (int ei = 0; ei < num_expert; ++ei) {
@@ -329,7 +363,7 @@ void fmoe_cuda_fused_backward_impl(
             int rank_send = j + from_base;
             int rank_recv = j + to_base;
             GEN_IDX;
-            _exchange_with(global_grad_in + global_ptr[gidx_send] * d_model,
+            exchangeWith(global_grad_in + global_ptr[gidx_send] * d_model,
                     global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                     grad_in + local_ptr[idx_recv] * d_model,
                     local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
@@ -341,36 +375,6 @@ void fmoe_cuda_fused_backward_impl(
     checkCudaErrors(cudaGetLastError());
-    /* TODO: Shadowing support
-    int offset = global_ptr[world_size * num_expert];
-    for (int j = 0; j < world_size; j++) {
-        for (int i = 0; i < num_expert; i++) {
-            int idx = j * num_expert + i;
-            if (!stored_models[idx])
-                continue;
-            weight1 = params[j][0][0].data_ptr<scalar_t>();
-            weight2 = params[j][0][last].data_ptr<scalar_t>();
-            grad_weight1 = params[j][0][0].mutable_grad().data_ptr<scalar_t>();
-            grad_weight2 = params[j][0][last].mutable_grad().data_ptr<scalar_t>();
-            auto stream = 2 + (idx % (SMGR_N_STREAMS- 2));
-            _compute_mlp_backward(
-                original_input_buf + local_ptr[idx] * d_model, weight1, weight2,
-                middle_buf + (offset + local_global_ptr[idx]) * d_hidden, output_buf, grad_out + local_ptr[idx] * d_model,
-                grad_middle + (offset + local_global_ptr[idx]) * d_hidden, grad_weight1, grad_weight2, grad_in + local_ptr[idx] * d_model,
-                i,
-                0, local_expert_count[idx],
-                d_model, d_hidden, 0, // we never consider it to be the first since it's already initialized to zero and we are lazy
-                smgr->stream(stream), smgr->handle(stream));
-        }
-    }
-    */
-
     delete [] local_ptr;
     delete [] global_ptr;
     delete [] local_global_ptr;
@@ -381,6 +385,12 @@ void fmoe_cuda_fused_backward_impl(
     }
     delete [] input_ready;
     delete [] output_ready;
+    for (long i = 0; i < num_expert; ++i) {
+        if (stored_models[i + rank * num_expert]) {
+            cudaEventDestroy(evt_reduce[i]);
+        }
+    }
+    delete [] evt_reduce;
 }
 
 #endif  // SMART_SCHEDULE_H
```
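The shadow-backward loops added above index experts globally: expert `i` lives on rank `i / num_expert` at local slot `i % num_expert`, and only the owning rank records `evt_reduce` and later waits on it before calling `set_grad_fn`. A small Python sketch of that indexing convention, with toy sizes chosen purely for illustration:

```python
# Toy sizes for illustration; FastMoE's actual values come from the layer config.
num_expert = 2   # experts per worker
world_size = 4   # number of workers

def owner_rank(i):
    # The worker that owns global expert i.
    return i // num_expert

def local_slot(i):
    # The expert's index within its owner's local expert list.
    return i % num_expert

for i in range(world_size * num_expert):
    print(f"global expert {i}: owner rank {owner_rank(i)}, local slot {local_slot(i)}")
```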
cuda/fmoe_cuda.cpp

```diff
@@ -77,13 +77,16 @@ torch::Tensor _smart_sch_backward(
         torch::Tensor stored_models,
         long buf_batch_size,
         long global_batch_size,
+        long expert_size,
         long n_workers,
         py::function backward_fn,
         py::function stash_fn,
         py::function pop_fn,
         py::function collect_fn,
         py::function set_grad_fn);
+void _reduce_grad(torch::Tensor t, long root, long expert_size);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifdef FMOE_USE_NCCL
@@ -95,6 +98,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("smart_sch_forward", &_smart_sch_forward, "E2E MoE layer forward with smart scheduling");
     m.def("smart_sch_backward", &_smart_sch_backward, "E2E MoE layer backward with smart scheduling");
+    m.def("reduce_grad", &_reduce_grad, "Reduce gradients over FastMoE's communication stream");
 #endif
     m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
```
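With the `reduce_grad` binding exposed, the Python side can pack a shadowed expert's gradients into a flat buffer and sum them onto the owning rank, which is how `collect_fn` in fmoe/fastermoe/schedule.py uses it below. A hedged usage sketch; the extension's import name and the surrounding setup are assumptions here, and `experts`, `grad_buf`, `root`, and `expert_size` are placeholders:

```python
import fmoe_cuda as fmoe_native            # assumed name of the built extension
from fmoe.fastermoe import expert_utils

def collect_and_reduce(experts, grad_buf, root, expert_size):
    # Pack each parameter's grad of the shadowed expert into the flat buffer,
    # then sum the buffer onto `root` over FastMoE's communication stream.
    expert_utils.collect_expert_grads(experts, grad_buf)
    fmoe_native.reduce_grad(grad_buf, root, expert_size)
```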
fmoe/fastermoe/expert_utils.py

```diff
@@ -6,7 +6,6 @@ def get_expert_param_size(e):
 def get_expert_params(e, out):
-    print('gep to {}'.format(out))
     offset = 0
     for n, p in e.named_parameters():
         seg = out[offset:offset + p.numel()]
@@ -42,7 +41,7 @@ def collect_expert_grads(e, grads):
         seg = grads[offset:offset + p.numel()]
         offset += p.numel()
         if p.grad is not None:
-            seg.copy_(p.grad)
+            seg.copy_(p.grad.flatten())
             p.grad = None
         else:
             seg.zero_()
@@ -56,4 +55,4 @@ def set_grads(e, grads):
         if p.grad is None:
             p.grad = seg.clone()
         else:
-            p.grad += seg
+            p.grad += seg.reshape(p.shape)
```
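The `.flatten()` / `.reshape(p.shape)` pair is needed because the gradients of all expert parameters are staged in a single 1-D buffer: copying in must flatten each `p.grad`, and applying the buffer back must restore each parameter's shape. A self-contained sketch of that round trip, with a toy `Linear` module standing in for an expert (illustrative only, not the library code):

```python
import torch

# Toy "expert" with populated gradients.
m = torch.nn.Linear(4, 3)
m(torch.randn(2, 4)).sum().backward()

flat = torch.zeros(sum(p.numel() for p in m.parameters()))

# Pack: analogous to collect_expert_grads, each grad is flattened into the buffer.
offset = 0
for p in m.parameters():
    seg = flat[offset:offset + p.numel()]
    seg.copy_(p.grad.flatten()) if p.grad is not None else seg.zero_()
    offset += p.numel()

# Unpack: analogous to set_grads, each segment is reshaped back to the parameter's shape.
offset = 0
for p in m.parameters():
    seg = flat[offset:offset + p.numel()]
    if p.grad is None:
        p.grad = seg.reshape(p.shape).clone()
    else:
        p.grad += seg.reshape(p.shape)
    offset += p.numel()
```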
fmoe/fastermoe/schedule.py

```diff
@@ -24,15 +24,14 @@ class MoEForward(Function):
             world_size):
         local_input_buf = _local_scatter(inp, pos_s)
 
-        # TODO: leave this for furture work of expert shadowing
-        # model_params = [[tuple(m.parameters()) for m in node] for node in models]
         ctx.gibs = [None] * (world_size * 2)
         ctx.gobs = [None] * (world_size * 2)
         def _expert_forward(x, y, idx):
+            nothing = lambda a: a
             x = x.data
             with torch.enable_grad():
                 x.requires_grad = True
-                y0 = expert_fn(x, [x.shape[0]])
+                with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
+                    y0 = expert_fn(x, [x.shape[0]])
             ctx.gibs[idx] = x
             ctx.gobs[idx] = y0
@@ -60,7 +59,7 @@ class MoEForward(Function):
                 maybe_overlap=False)
 
         variables = (pos_s, pos_g, local_expert_count, global_expert_count,
-                stored_models, gib)
+                stored_models, gib, local_input_buf)
         ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
         ctx.save_for_backward(*variables)
@@ -70,30 +69,33 @@ class MoEForward(Function):
     @staticmethod
     def backward(ctx, grad_out):
         (pos_s, pos_g, local_expert_count, global_expert_count,
-                stored_models, _) = ctx.saved_tensors
+                stored_models, _1, _2) = ctx.saved_tensors
         (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args
 
         def _expert_backward(grad_y, grad_x, idx):
             y = ctx.gobs[idx]
-            torch.autograd.backward([y], [grad_y])
             x = ctx.gibs[idx]
+            torch.autograd.backward([y], [grad_y])
             grad_x.copy_(x.grad)
 
         experts = ctx.experts
         def stash_fn(idx):
             expert_utils.stash_expert_params(experts, ctx.shadows[idx])
         pop_fn = lambda: expert_utils.pop_expert_params(experts)
-        collect_fn = lambda g: expert_utils.collect_expert_grads(experts, g)
-        set_grad_fn = lambda g: expert_utils.set_grads(experts, g)
+        def collect_fn(idx, root):
+            grad = ctx.shadows[idx]
+            expert_utils.collect_expert_grads(experts, grad)
+            fmoe_native.reduce_grad(grad, root, ctx.expert_size)
+        set_grad_fn = lambda idx: expert_utils.set_grads(experts, ctx.shadows[idx])
 
         grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
         grad_in_buf = fmoe_native.smart_sch_backward(
             grad_out_buf,
             local_expert_count, global_expert_count,
             stored_models,
-            pos_s.shape[0], fwd_batch_size,
-            world_size, _expert_backward,
+            pos_s.shape[0], fwd_batch_size, ctx.expert_size,
+            world_size,
+            _expert_backward,
             stash_fn, pop_fn, collect_fn, set_grad_fn)
         grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)
 
         return (None, None, grad_in, None, None, None, None, None, None, None, None)
```
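`_expert_forward` now runs the expert under `torch.autograd.graph.saved_tensors_hooks` with identity pack/unpack hooks (`nothing`), presumably so autograd's saved-tensor machinery stays out of the way of activations the scheduler already keeps in `ctx.gibs` / `ctx.gobs`. A self-contained illustration that gradient computation behaves as usual under identity hooks (the toy graph here is an assumption for demonstration, not the library's expert function):

```python
import torch

def nothing(a):
    # Identity pack/unpack hook, same trick as in _expert_forward.
    return a

x = torch.randn(8, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
    y = (x * x).sum()

# Mirrors the torch.autograd.backward([y], [grad_y]) call used in _expert_backward.
torch.autograd.backward([y], [torch.ones_like(y)])
print(x.grad is not None)  # True: backward still works through the identity hooks
```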
tests/test_faster_shadow.py

```diff
@@ -62,16 +62,18 @@ def _test_faster_shadow(d_model, batch_size, n_expert):
     dist.broadcast(stored_models, 0)
     stored_models = stored_models.cpu()
+    # if rank == 0:
+    #     print('stored models {}'.format(stored_models))
     ensure_comm(x1, None)
 
     y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1, stored_models=stored_models)
-    # y1.sum().backward()
+    y1.sum().backward()
 
     y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2)
-    # y2.sum().backward()
+    y2.sum().backward()
 
-    _assert_numerical(['out'], [y1], [y2], rank)
-    # _assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
-    #         [y1, x1.grad, m1.bias.grad, m1.weight.grad],
-    #         [y2, x2.grad, m2.bias.grad, m2.weight.grad], rank)
+    _assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
+            [y1, x1.grad, m1.bias.grad, m1.weight.grad],
+            [y2, x2.grad, m2.bias.grad, m2.weight.grad], rank)
 
 
 if __name__ == '__main__':
```
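The test now runs backward through both the smart-schedule path and the naive path and compares input and parameter gradients, not only the outputs. A hedged sketch of the kind of element-wise check `_assert_numerical` presumably performs; the helper `assert_close`, its tolerances, and the commented call are assumptions mirroring the test's variable names, not the test's actual implementation:

```python
import torch

def assert_close(names, ours, refs, rtol=1e-4, atol=1e-5):
    # Compare each pair of tensors and report which quantity diverged.
    for name, a, b in zip(names, ours, refs):
        assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{name} mismatch"

# Hypothetical usage mirroring the widened check in the test:
# assert_close(['out', 'grad_in', 'grad_bias', 'grad_weight'],
#              [y1, x1.grad, m1.bias.grad, m1.weight.grad],
#              [y2, x2.grad, m2.bias.grad, m2.weight.grad])
```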