OpenDAS / FastMoE · Commits

Commit a807e2a3
Authored Mar 29, 2022 by Rick Ho

    backward bugous on grad weight

Parent 46c3722d

Showing 5 changed files with 71 additions and 89 deletions.
    cuda/fastermoe/smart_schedule.cpp   +15  -49
    cuda/fastermoe/smart_schedule.h      +8  -20
    cuda/fmoe_cuda.cpp                  +10   -0
    fmoe/fastermoe/schedule.py           +9   -8
    tests/test_faster_schedule.py       +29  -12
cuda/fastermoe/smart_schedule.cpp

@@ -17,7 +17,6 @@ torch::Tensor _smart_sch_forward(
         long global_batch_size,
         long n_workers,
         py::function forward_fn) {
     if (pipeline_gran == -1) {
         char* p = getenv("FMOE_FASTER_GROUP_SIZE");
         if (p) {
@@ -40,7 +39,7 @@ torch::Tensor _smart_sch_forward(
     auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
-            "fmoe_cuda_fused_forward", ([&] {
+            "fmoe_cuda_smart_sch_forward", ([&] {
         fmoe_cuda_fused_forward_impl(
             forward_fn,
             input_buf.device(),
@@ -59,75 +58,42 @@ torch::Tensor _smart_sch_forward(
     return output_buf;
 }
 
-/*
-std::vector<torch::Tensor> _fused_backward(
-        torch::Tensor input_buf,
-        std::vector<std::vector<std::vector<torch::Tensor>>> params,
-        torch::Tensor middle_buf,
-        torch::Tensor output_buf,
+torch::Tensor _smart_sch_backward(
         torch::Tensor grad_out,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
-        torch::Tensor inp,
         torch::Tensor stored_models,
-        long global_batch_size,
         long buf_batch_size,
-        long n_workers, bool has_bias) {
+        long global_batch_size,
+        long n_workers,
+        py::function backward_fn) {
     const auto num_expert = local_expert_count.size(0) / n_workers;
-    auto smgr = getCudaStreamManager(input_buf.device().index());
+    auto smgr = getCudaStreamManager(grad_out.device().index());
     int rank;
     ncclCommUserRank(smgr->ncclcomm, &rank);
-    const auto d_hidden = params[rank][0][0].size(1);
-    const auto d_model = params[rank][0][0].size(2);
-    auto global_grad_out = input_buf.new_zeros({global_batch_size, d_model});
-    auto grad_middle = input_buf.new_zeros({global_batch_size, d_hidden});
-    auto global_grad_in = input_buf.new_zeros({global_batch_size, d_model});
-    auto grad_in = input_buf.new_zeros({buf_batch_size, d_model});
-    for (auto node : params)
-        for (auto expert : node)
-            for (int i = 0; i < expert.size(); i++) {
-                // create the respective gradient of each tensor
-                CHECK_INPUT(expert[i]);
-                if (expert[i].grad().defined()) {
-                    CHECK_INPUT(expert[i].grad());
-                    continue;
-                }
-                expert[i].mutable_grad() = input_buf.new_zeros(expert[i].sizes());
-            }
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
-            "fmoe_cuda_fused_backward", ([&] {
+    const auto d_model = grad_out.size(1);
+    auto global_grad_out = grad_out.new_zeros({global_batch_size, d_model});
+    auto global_grad_in = grad_out.new_zeros({global_batch_size, d_model});
+    auto grad_in = grad_out.new_zeros({buf_batch_size, d_model});
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_out.scalar_type(),
+            "fmoe_cuda_smartsch_backward", ([&] {
         fmoe_cuda_fused_backward_impl(
-            input_buf.data_ptr<scalar_t>(),
-            inp.data_ptr<scalar_t>(),
-            params,
-            middle_buf.data_ptr<scalar_t>(),
-            output_buf.data_ptr<scalar_t>(),
+            backward_fn,
+            grad_out.device(),
             grad_out.data_ptr<scalar_t>(),
             global_grad_out.data_ptr<scalar_t>(),
             global_grad_in.data_ptr<scalar_t>(),
-            grad_middle.data_ptr<scalar_t>(),
             grad_in.data_ptr<scalar_t>(),
             local_expert_count.data_ptr<long>(),
             global_expert_count.data_ptr<long>(),
             stored_models.data_ptr<bool>(),
-            d_model,
-            d_hidden,
-            num_expert, rank, n_workers,
-            has_bias,
+            d_model,
+            num_expert, rank, n_workers,
             pipeline_gran,
             smgr);
     }));
     return {grad_in,};
 }
-*/
 #endif
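
After this change the C++ backward no longer touches expert parameters at all: the old, commented-out _fused_backward walked params and allocated .mutable_grad() buffers itself, while the new _smart_sch_backward only allocates and exchanges gradient buffers and defers the per-expert math to a py::function callback, mirroring _smart_sch_forward. A minimal sketch of the callback contract it expects, written from the Python side (the helper name make_backward_fn and the saved_inputs/saved_outputs containers are illustrative, not part of the extension API):

import torch

def make_backward_fn(saved_inputs, saved_outputs):
    # The scheduler calls this once per micro-batch: grad_y holds the gradient
    # w.r.t. the expert output, and grad_x is a preallocated buffer that must
    # be filled with the gradient w.r.t. the expert input.
    def backward_fn(grad_y, grad_x, idx):
        y = saved_outputs[idx]      # expert output recorded during forward
        x = saved_inputs[idx]       # leaf input recorded during forward
        torch.autograd.backward([y], [grad_y])
        grad_x.copy_(x.grad)        # write back in place; nothing is returned
    return backward_fn

This matches the _expert_backward closure added to fmoe/fastermoe/schedule.py below.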
cuda/fastermoe/smart_schedule.h

@@ -74,7 +74,7 @@ void _compute_ptrs(long num_expert, long rank, long world_size,
 }
 
 template<typename scalar_t>
-void _compute_forward(py::function fn, c10::Device device,
+void _compute_fn(py::function fn, c10::Device device,
         scalar_t* inp_buf, scalar_t* out_buf,
         int ei, long step, long offset, long micro_batch_size, long d_model) {
     auto options = torch::TensorOptions()
@@ -89,14 +89,6 @@ void _compute_forward(py::function fn, c10::Device device,
 }
 
-template<typename scalar_t>
-void _compute_backward(py::function fn,
-        scalar_t* inp_buf, scalar_t* out_buf,
-        long* local_expert_count, long* global_expert_count,
-        int ei,
-        long offset, long micro_batch_size, long d_model) {
-}
-
 template<typename scalar_t>
 void fmoe_cuda_fused_forward_impl(
         py::function forward_fn,
@@ -162,7 +154,7 @@ void fmoe_cuda_fused_forward_impl(
             long micro_batch_size = global_ptr[ei * world_size +
                 (from_base + pipeline_gran)] - offset;
-            _compute_forward(forward_fn, device,
+            _compute_fn(forward_fn, device,
                     global_input_buf, global_output_buf,
                     ei, step, offset, micro_batch_size, d_model);
         }
@@ -230,19 +222,17 @@ void fmoe_cuda_fused_forward_impl(
 template<typename scalar_t>
 void fmoe_cuda_fused_backward_impl(
         py::function backward_fn,
-        const scalar_t* input_buf,
-        const scalar_t* output_buf,
-        scalar_t* grad_out,
+        c10::Device device,
+        const scalar_t* grad_out,
         scalar_t* global_grad_out,
         scalar_t* global_grad_in,
         scalar_t* grad_in,
         const long* local_expert_count,
         const long* global_expert_count,
         const bool* stored_models,
-        long d_model, long d_hidden,
+        long d_model,
         long num_expert, long rank, long world_size,
         long pipeline_gran,
         CudaStreamManager* smgr) {
@@ -294,11 +284,9 @@ void fmoe_cuda_fused_backward_impl(
             long micro_batch_size = global_ptr[ei * world_size +
                 (from_base + pipeline_gran)] - offset;
-            _compute_backward(backward_fn,
-                    input_buf, output_buf,
+            _compute_fn(backward_fn, device,
                     global_grad_out,
                     global_grad_in,
-                    ei, offset, micro_batch_size);
+                    ei, step, offset, micro_batch_size, d_model);
         }
     }
     // TODO: get pytorch's compute stream
 }
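
With _compute_backward removed, the single _compute_fn helper now drives both directions; only the buffers handed to the callback change. A rough Python-level picture of that calling convention (the helper below is an illustration of the idea, not the real code, which slices raw device pointers by offset and micro_batch_size before wrapping them as tensors):

def compute_fn(fn, buf_in, buf_out, idx, offset, micro_batch_size):
    # View one micro-batch of each flat [batch, d_model] buffer and let the
    # Python callback fill the output slice in place.
    inp = buf_in[offset:offset + micro_batch_size]
    out = buf_out[offset:offset + micro_batch_size]
    fn(inp, out, idx)

# Forward: fn is the expert forward and copies expert(inp) into out.
# Backward: the same helper is reused with fn = backward_fn,
#           buf_in = global_grad_out and buf_out = global_grad_in.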
cuda/fmoe_cuda.cpp

@@ -65,6 +65,15 @@ torch::Tensor _smart_sch_forward(
         torch::Tensor stored_models,
         long global_batch_size, long n_workers,
         py::function forward_fn);
+torch::Tensor _smart_sch_backward(
+        torch::Tensor grad_out,
+        torch::Tensor local_expert_count,
+        torch::Tensor global_expert_count,
+        torch::Tensor stored_models,
+        long buf_batch_size,
+        long global_batch_size, long n_workers,
+        py::function backward_fn);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifdef FMOE_USE_NCCL
@@ -75,6 +84,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("swipe_once", &_swipe_once, "SWIPE balance strategy(CUDA)");
     m.def("smart_sch_forward", &_smart_sch_forward, "E2E MoE layer forward with smart scheduling");
+    m.def("smart_sch_backward", &_smart_sch_backward, "E2E MoE layer backward with smart scheduling");
 #endif
     m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
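
The new binding makes the backward entry point reachable from Python as smart_sch_backward on the compiled extension (imported as fmoe_native elsewhere in the tree). Going by the declaration added above, a call is expected to look roughly like the following; the local variable names are placeholders, and backward_fn follows the two-buffer callback contract sketched earlier:

grad_in_buf = fmoe_native.smart_sch_backward(
    grad_out_buf,                             # [buf_batch_size, d_model] gradient of the scattered output
    local_expert_count, global_expert_count,  # per-expert token counts, as in the forward call
    stored_models,                            # same flags the forward call used
    buf_batch_size,                           # pos_s.shape[0] in schedule.py
    global_batch_size,                        # fwd_batch_size in schedule.py
    world_size,
    backward_fn,
)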
fmoe/fastermoe/schedule.py

@@ -29,8 +29,9 @@ class MoEForward(Function):
         ctx.gobs = [None] * world_size
         def _expert_forward(x, y, idx):
             x = x.data
-            x.requires_grad = True
-            y0 = expert_fn(x, [x.shape[0]])
+            with torch.enable_grad():
+                x.requires_grad = True
+                y0 = expert_fn(x, [x.shape[0]])
             ctx.gibs[idx] = x
             ctx.gobs[idx] = y0
             y.copy_(y0)
@@ -55,21 +56,21 @@ class MoEForward(Function):
     @staticmethod
     def backward(ctx, grad_out):
         (pos_s, pos_g, local_expert_count, global_expert_count,
-                gib, gmb, gob, stored_models) = ctx.saved_tensors
+                stored_models) = ctx.saved_tensors
         (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args
 
-        def _expert_backward(grad, idx):
+        def _expert_backward(grad_y, grad_x, idx):
             y = ctx.gobs[idx]
-            torch.autograd.backward([y], [grad])
+            torch.autograd.backward([y], [grad_y])
             x = ctx.gibs[idx]
-            return x.grad
+            grad_x.copy_(x.grad)
 
         grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
         grad_in_buf = fmoe_native.smart_sch_backward(
-                gib, gmb, gob, grad_out_buf,
+                grad_out_buf,
                 local_expert_count, global_expert_count,
                 stored_models,
-                fwd_batch_size, pos_s.shape[0],
+                pos_s.shape[0], fwd_batch_size,
                 world_size, _expert_backward)
         grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)
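
The with torch.enable_grad(): wrapper matters because MoEForward.forward, like any torch.autograd.Function.forward, runs with gradient recording disabled; without it the expert sub-graph built in _expert_forward has no grad_fn, and the torch.autograd.backward([y], [grad_y]) call in the new _expert_backward has nothing to traverse. A small self-contained illustration of the difference (toy module, not FastMoE code):

import torch

lin = torch.nn.Linear(4, 4)
x = torch.rand(2, 4)

with torch.no_grad():                 # forward() of an autograd.Function runs like this
    a = x.data
    a.requires_grad = True
    y_plain = lin(a)                  # no graph recorded: y_plain.grad_fn is None

    b = x.data.clone()
    with torch.enable_grad():         # what the patch adds around the expert call
        b.requires_grad = True
        y_tracked = lin(b)            # graph recorded: backward() can reach b.grad

print(y_plain.grad_fn, y_tracked.grad_fn)

The backward callback likewise stops returning x.grad and instead copies it into the buffer the C++ scheduler hands over, so the gradient memory stays owned by the extension side.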
tests/test_faster_schedule.py

@@ -19,7 +19,8 @@ from fmoe.layers import _fmoe_general_global_forward as naive_fwd
 @pytest.mark.parametrize("d_model", [1024])
 @pytest.mark.parametrize("batch_size", [16])
 @pytest.mark.parametrize("n_expert", [1])
-def test_faster_schedule(n_process, d_model, batch_size, n_expert):
+@pytest.mark.parametrize("group_sz", [1, 2, 4])
+def test_faster_schedule(n_process, d_model, batch_size, n_expert, group_sz):
     _run_distributed('_test_faster_schedule',
             n_process,
             {
@@ -28,7 +29,9 @@ def test_faster_schedule(n_process, d_model, batch_size, n_expert):
                 'n_expert': n_expert
             },
             script=__file__,
-            env=dict()
+            env=dict(
+                FMOE_FASTER_GROUP_SIZE=str(group_sz)
+            )
     )
@@ -37,19 +40,33 @@ def _test_faster_schedule(d_model, batch_size, n_expert):
     rank = dist.get_rank()
     world_size = dist.get_world_size()
-    x = torch.rand(batch_size, d_model).cuda()
-    x.requires_grad = True
+    x1 = torch.rand(batch_size, d_model).cuda()
+    x1.requires_grad = True
+    x2 = x1.data.clone()
+    x2.requires_grad = True
     topk_idx = torch.randint(0, world_size * n_expert, (batch_size, 2)).cuda()
-    m = torch.nn.Linear(d_model, d_model).cuda()
+    m1 = torch.nn.Linear(d_model, d_model).cuda()
+    m2 = torch.nn.Linear(d_model, d_model).cuda()
+    with torch.no_grad():
+        m2.weight.copy_(m1.weight)
+        m2.bias.copy_(m1.bias)
-    def expert_fn(x, fec):
-        y = m(x)
+    def ef1(x, fec):
+        y = m1(x)
         return y
+    def ef2(x, fec):
+        y = m2(x)
+        return y
-    ensure_comm(x, None)
-    y = smart_fwd(x, topk_idx, expert_fn, n_expert, world_size)
-    z = naive_fwd(x, topk_idx, expert_fn, n_expert, world_size)
-    _assert_numerical(['out'],
-            [y], [z], rank)
+    ensure_comm(x1, None)
+    y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size)
+    y1.sum().backward()
+    y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size)
+    y2.sum().backward()
+    _assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
+            [y1, x1.grad, m1.bias.grad, m1.weight.grad],
+            [y2, x2.grad, m2.bias.grad, m2.weight.grad],
+            rank)
@@ -57,5 +74,5 @@ if __name__ == '__main__':
     args = json.loads(sys.argv[2])
     locals()[sys.argv[1]](**args)
 else:
-    # test_faster_schedule(8, 16, 16, 1)
+    # test_faster_schedule(8, 16, 16, 1, 2)
     _test_faster_schedule(4, 2, 1)
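
The new group_sz parameter reaches the CUDA extension only through the FMOE_FASTER_GROUP_SIZE environment variable: _run_distributed forwards it into each worker's environment, and _smart_sch_forward reads it with getenv when no explicit pipeline granularity is configured. One way to exercise the updated test, assuming a CUDA/NCCL build of the extension and that pytest is available (this invocation is a suggestion, not part of the diff):

import pytest

# Runs test_faster_schedule for every (d_model, batch_size, n_expert, group_sz)
# combination; each case spawns workers via _run_distributed with
# FMOE_FASTER_GROUP_SIZE set to that case's group size.
raise SystemExit(pytest.main(["-x", "tests/test_faster_schedule.py"]))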