OpenDAS / FastMoE
Commit 5680c599

Merge branch 'master' into laekov/gate

Authored May 20, 2021 by Rich Ho
Parents: 90c4bccf, 3c42c892

Showing 11 changed files with 250 additions and 117 deletions (+250 -117)
cuda/fmoe_cuda.cpp            +12  -8
cuda/parallel_linear.cu       +37  -30
cuda/parallel_linear.cuh      +74  -3
fmoe/functions.py             +11  -7
fmoe/layers.py                 +1  -31
setup.py                       +1  -1
tests/test_ddp.py              +7  -4
tests/test_gates.py           +49  -13
tests/test_local_exchange.py   +3  -3
tests/test_numerical.py       +20  -11
tests/test_zero.py            +35  -6
cuda/fmoe_cuda.cpp

@@ -30,13 +30,17 @@ void _assign_pos(

 // parallel_linear
 std::vector<torch::Tensor> _linear_forward(
         torch::Tensor input_buf,
+        torch::Tensor expert_count,
         torch::Tensor weight,
-        torch::Tensor expert_count);
+        at::optional<torch::Tensor> bias);
 std::vector<torch::Tensor> _linear_backward(
         torch::Tensor grad_output_buf,
         torch::Tensor input_buf,
+        torch::Tensor expert_count,
         torch::Tensor weight,
-        torch::Tensor expert_count);
+        at::optional<torch::Tensor> bias);

 // balancing
 std::vector<torch::Tensor> _limit_by_capacity(
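Both entry points now take expert_count before weight and accept the bias as at::optional<torch::Tensor>, so a layer built without a bias can simply pass None from Python. A minimal sketch of the expected calls against the rebuilt extension (shapes follow the comments in this diff; the values are illustrative):

import torch
import fmoe_cuda  # the extension compiled by setup.py

inp = torch.rand(6, 4).cuda()   # 6 tokens, in_feat = 4
cnt = torch.tensor([2, 4])      # tokens per expert, CPU long tensor
w = torch.rand(2, 8, 4).cuda()  # [num_expert, out_feat, in_feat]
b = torch.rand(2, 8).cuda()     # [num_expert, out_feat]

# pybind11 converts Python None into an empty at::optional
(out_biased,) = fmoe_cuda.linear_forward(inp, cnt, w, b)
(out_plain,) = fmoe_cuda.linear_forward(inp, cnt, w, None)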
cuda/parallel_linear.cpp → cuda/parallel_linear.cu

-#include "parallel_linear.h"
+#include "parallel_linear.cuh"
 #include "utils/fmoe_utils.h"
 #include <torch/extension.h>

 std::vector<torch::Tensor> _linear_forward(
         torch::Tensor input_buf,
+        torch::Tensor expert_count,
         torch::Tensor weight,
-        torch::Tensor expert_count
+        at::optional<torch::Tensor> bias
         ) {
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);

     auto smgr = getCudaStreamManager(input_buf.device().index());

     const auto batch_size = input_buf.size(0);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);

-#ifdef MOE_DEBUG
+#ifdef FMOE_DEBUG
     printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
             num_expert, in_feat, out_feat);
 #endif

-    auto out_options = torch::TensorOptions()
-        .device(input_buf.device())
-        .dtype(input_buf.dtype());
-    auto output = torch::empty({batch_size, out_feat}, out_options);
+    torch::Tensor output;
+    if (bias.has_value()) {
+        output = bias.value().repeat_interleave(
+                expert_count.to(bias.value().device()), 0);
+    } else {
+        auto out_options = torch::TensorOptions()
+            .device(input_buf.device())
+            .dtype(input_buf.dtype());
+        output = torch::empty({batch_size, out_feat}, out_options);
+    }

-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "fmoe_linear_forward",
             ([&] {
-        fmoe_cuda_forward_impl<scalar_t>(
+        fmoe_cuda_linear_forward_impl<scalar_t>(
             input_buf.data_ptr<scalar_t>(),
             weight.data_ptr<scalar_t>(),
             expert_count.data_ptr<long>(),
             output.data_ptr<scalar_t>(),
+            bias.has_value(),
             in_feat, out_feat, num_expert,
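When a bias is present, it is never added in a separate pass: the output buffer is pre-filled with the right bias row for every token (repeat_interleave over expert_count), and the per-expert GEMM then runs with beta = 1 so cuBLAS accumulates on top of it. A pure-PyTorch sketch of the same computation (a reference only, not the CUDA path):

import torch

def linear_forward_ref(inp, expert_count, weight, bias=None):
    # inp: [batch, in_feat]; weight: [num_expert, out_feat, in_feat]
    if bias is not None:
        # seed with one bias row per token, like the beta = 1 GEMM path
        out = bias.repeat_interleave(expert_count, 0)
    else:
        out = torch.zeros(inp.shape[0], weight.shape[1],
                          device=inp.device, dtype=inp.dtype)
    ptr = 0
    for i, cnt in enumerate(expert_count.tolist()):
        if cnt == 0:
            continue
        out[ptr:ptr + cnt] += inp[ptr:ptr + cnt] @ weight[i].t()
        ptr += cnt
    return out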
@@ -42,23 +48,21 @@ std::vector<torch::Tensor> _linear_forward(
     return {output, };
 }

 std::vector<torch::Tensor> _linear_backward(
-        torch::Tensor grad_output_buf,  // [batch_size x out_feat]
-        torch::Tensor input_buf,        // [batch_size x out_feat]
-        torch::Tensor weight,           // [num_expert x out_feat x in_feat]
-        torch::Tensor expert_count
+        torch::Tensor grad_output_buf,
+        torch::Tensor input_buf,
+        torch::Tensor expert_count,
+        torch::Tensor weight,
+        at::optional<torch::Tensor> bias
         ) {
     CHECK_INPUT(grad_output_buf);
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);

     auto smgr = getCudaStreamManager(input_buf.device().index());

     const auto batch_size = input_buf.size(0);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);

-#ifdef MOE_DEBUG
+#ifdef FMOE_DEBUG
     printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
             "out_feat (d_ffn)=%ld\n",
             batch_size, num_expert, in_feat, out_feat);
@@ -66,15 +70,18 @@ std::vector<torch::Tensor> _linear_backward(
     auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
     auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
+    auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});

-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward",
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "fmoe_linear_backward",
             ([&] {
-        fmoe_cuda_backward_impl<scalar_t>(
+        fmoe_cuda_linear_backward_impl<scalar_t>(
             grad_output_buf.data_ptr<scalar_t>(),
             input_buf.data_ptr<scalar_t>(),
             weight.data_ptr<scalar_t>(),
             expert_count.data_ptr<long>(),
             grad_input_buf.data_ptr<scalar_t>(),
             grad_weight.data_ptr<scalar_t>(),
+            grad_bias.data_ptr<scalar_t>(),
+            bias.has_value(),
             batch_size, in_feat, out_feat,
@@ -83,6 +90,6 @@ std::vector<torch::Tensor> _linear_backward(
             );
     }));

-    return {grad_input_buf, grad_weight};
+    return {grad_input_buf, grad_weight, grad_bias};
 }
cuda/parallel_linear.h → cuda/parallel_linear.cuh

@@ -2,17 +2,68 @@
 #include "utils/cublas_wrapper.h"

+/*
+    This function is to be called with one block per each column
+*/
 template <typename scalar_t>
-void fmoe_cuda_forward_impl(
+__global__
+void column_reduce(const scalar_t * matrix, scalar_t * result,
+        int m /* lines */, int n /* columns */) {
+    // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
+    extern __shared__ unsigned char my_smem[];
+    scalar_t *sdata = reinterpret_cast<scalar_t*>(my_smem);
+
+    // normal tid
+    int tid = threadIdx.x + threadIdx.y * blockDim.x;
+    // transposed tid for shared memory
+    int new_tid = threadIdx.y + threadIdx.x * blockDim.y;
+
+    // true x value in the matrix
+    int real_x = threadIdx.x + blockDim.x * blockIdx.x;
+
+    int i = real_x + n * threadIdx.y;
+    const int it = n * blockDim.y;
+    int offset = it;
+
+    float accumulator = 0;
+
+    if (threadIdx.y < m && real_x < n) {
+        // store all the values from this column in a warped way
+        accumulator = matrix[i];
+        while (i + offset < n * m) {
+            accumulator += matrix[i + offset];
+            offset += it;
+        }
+    }
+
+    // save column reduction data in a transposed way
+    sdata[new_tid] = accumulator;
+    __syncthreads();
+
+    for (size_t t = 16; t > 0; t >>= 1) {
+        if (tid < 32 * 32 - 16)
+            sdata[tid] += sdata[tid + t];
+        __syncthreads();
+    }
+
+    if (threadIdx.y == 0 && real_x < n)
+        result[real_x] = sdata[new_tid];
+}
+
+template <typename scalar_t>
+void fmoe_cuda_linear_forward_impl(
         const scalar_t* input_buf,
         const scalar_t* weight,
         const long* expert_count,
         scalar_t* output_buf,
+        const bool has_bias,
         const size_t in_feat,
         const size_t out_feat,
         const size_t num_expert,
         CudaStreamManager* smgr) {
-    scalar_t alpha = 1, beta = 0;
+    scalar_t alpha = 1, beta = has_bias ? 1 : 0;

     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
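column_reduce sums each column of an m-row, n-column row-major matrix: every 32x32 block strides down its 32 columns accumulating partial sums, stages them transposed in dynamic shared memory, and finishes with a tree reduction. In PyTorch terms the kernel is a column sum; a tiny reference sketch for sanity-checking it:

import torch

def column_reduce_ref(matrix):
    # matrix: [m, n]; returns [n], the per-column sums
    return matrix.sum(dim=0)

mat = torch.rand(100, 70)
assert torch.allclose(column_reduce_ref(mat), torch.ones(100) @ mat)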
@@ -37,13 +88,15 @@ void fmoe_cuda_linear_forward_impl(
 }

 template <typename scalar_t>
-void fmoe_cuda_backward_impl(
+void fmoe_cuda_linear_backward_impl(
         const scalar_t* grad_output_buf,
         const scalar_t* input_buf,
         const scalar_t* weight,
         const long* expert_count,
         scalar_t* grad_input_buf,
         scalar_t* grad_weight,
+        scalar_t* grad_bias,
+        const bool has_bias,
         const size_t batch_size,
         const size_t in_feat,
         const size_t out_feat,
@@ -51,10 +104,16 @@ void fmoe_cuda_linear_backward_impl(
         CudaStreamManager* smgr) {
     scalar_t alpha = 1, beta = 0;

+    // bias
+    dim3 block_threads(32, 32);
+    dim3 grid_threads(out_feat / 32 + (out_feat % 32 ? 1 : 0), 1);
+
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
             cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                     sizeof(scalar_t) * in_feat * out_feat);
+            cudaMemset(grad_bias + i * out_feat, 0,
+                    sizeof(scalar_t) * out_feat);
             continue;
         }

         // Use T(B) x T(A) = T(C) to produce row-major C
@@ -85,7 +144,19 @@ void fmoe_cuda_linear_backward_impl(
                 grad_weight + i * in_feat * out_feat,
                 in_feat
                 ));

+        if (has_bias) {
+            column_reduce
+                <<<grid_threads, block_threads,
+                    sizeof(scalar_t) * 1024, smgr->stream(0)>>>(
+                grad_output_buf + ptr * out_feat,
+                grad_bias + i * out_feat,
+                expert_count[i],
+                out_feat);
+        }
+
         ptr += expert_count[i];
     }
     smgr->sync(num_expert);
 }
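Per expert, the bias gradient is the column sum of that expert's slice of grad_output, which is exactly what the column_reduce launch above computes (one 32-wide column tile per block, 1024 shared scalars for the 32x32 staging buffer); idle experts get grad_weight and grad_bias zeroed explicitly because the loop skips their GEMMs. A hedged PyTorch reference for the whole grad_bias:

import torch

def grad_bias_ref(grad_output, expert_count):
    # grad_output: [batch, out_feat], rows grouped by expert
    out_feat = grad_output.shape[1]
    grads, ptr = [], 0
    for cnt in expert_count.tolist():
        if cnt == 0:
            grads.append(torch.zeros(out_feat, dtype=grad_output.dtype,
                                     device=grad_output.device))
            continue
        grads.append(grad_output[ptr:ptr + cnt].sum(dim=0))
        ptr += cnt
    return torch.stack(grads)   # [num_expert, out_feat]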
fmoe/functions.py

@@ -147,21 +147,25 @@ class MOELinear(Function):
     """

     @staticmethod
-    def forward(ctx, global_input_buf, weight, fwd_expert_count):
+    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
         (global_output_buf,) = fmoe_cuda.linear_forward(
-            global_input_buf, weight, fwd_expert_count
+            global_input_buf, fwd_expert_count, weight, bias
         )
-        variables = (global_input_buf, weight, fwd_expert_count)
+        variables = (global_input_buf, fwd_expert_count, weight, bias)
         ctx.save_for_backward(*variables)
         return global_output_buf

     @staticmethod
     def backward(ctx, grad_out):
-        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
-        grad_inp_buf, grad_weight = fmoe_cuda.linear_backward(
-            grad_out, input_buf, weight, fwd_expert_count
+        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
+        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
+            grad_out, input_buf, fwd_expert_count, weight, bias
         )
-        return grad_inp_buf, grad_weight, None
+
+        if not torch.is_tensor(bias):
+            grad_bias = None
+
+        return grad_inp_buf, None, grad_weight, grad_bias


 class MOEGather(Function):
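Since autograd requires one gradient slot per forward argument, backward now returns (grad_input, None, grad_weight, grad_bias): None for the integer fwd_expert_count, and grad_bias collapsed to None when no bias tensor was saved. A short usage sketch of the updated Function (assuming a CUDA build of fmoe):

import torch
from fmoe.functions import MOELinear

inp = torch.rand(6, 4, device='cuda', requires_grad=True)
cnt = torch.tensor([2, 4])   # tokens per expert, kept on CPU
w = torch.rand(2, 8, 4, device='cuda', requires_grad=True)
b = torch.rand(2, 8, device='cuda', requires_grad=True)

out = MOELinear.apply(inp, cnt, w, b)   # [6, 8]
out.sum().backward()                    # populates inp.grad, w.grad, b.grad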
fmoe/layers.py

@@ -41,37 +41,7 @@ class FMoELinear(nn.Module):
         r"""
         Call MOE function
         """
-        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
-        if self.bias is not None:
-            # TODO: torch.repeat_interleave seems to have numerical
-            # instability in backward, leading to incorrect
-            # gradient computation for solutions 1 and 2.
-            # Solution 3 uses a for-loop to expand the bias,
-            # but is 50% slower.
-            # This part should finally go into MOELinear.apply,
-            # like MOELinear.apply(x, weight, bias, count)
-            # Solution 1
-            bias = torch.repeat_interleave(
-                self.bias, fwd_expert_count.to(self.bias.device), dim=0
-            )
-            # Solution 2
-            # bias_idx = torch.arange(self.num_expert)\
-            #         .repeat_interleave(fwd_expert_count)
-            # bias = self.bias[bias_idx]
-            # Solution 3
-            # bias = []
-            # for i in range(self.num_expert):
-            #     if fwd_expert_count[i] > 0:
-            #         bias.append(
-            #             self.bias[i].unsqueeze(0).expand(
-            #                 fwd_expert_count[i], -1
-            #             )
-            #         )
-            # bias = torch.cat(bias, dim=0)
-            x = x + bias
+        x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
         return x

     def extra_repr(self) -> str:
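The deletion resolves the removed TODO exactly as it proposed: the bias moves inside MOELinear.apply, so the fused kernel adds each expert's bias row directly instead of materializing it with repeat_interleave, whose backward had shown the numerical instability noted in the comment. What the layer computes per token, as a plain PyTorch sketch (reference semantics only, not the fused path):

import torch

def fmoe_linear_ref(inp, fwd_expert_count, weight, bias=None):
    # token t in expert e's slice gets y_t = x_t @ weight[e].T (+ bias[e])
    expert_idx = torch.repeat_interleave(
        torch.arange(len(fwd_expert_count)), fwd_expert_count)
    out = torch.einsum('bi,boi->bo', inp, weight[expert_idx])
    if bias is not None:
        out = out + bias[expert_idx]
    return out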
setup.py

@@ -29,7 +29,7 @@ if __name__ == '__main__':
                     'cuda/local_exchange.cu',
                     'cuda/balancing.cu',
                     'cuda/global_exchange.cpp',
-                    'cuda/parallel_linear.cpp',
+                    'cuda/parallel_linear.cu',
                     'cuda/fmoe_cuda.cpp',
                 ],
                 extra_compile_args={
tests/test_ddp.py

@@ -11,7 +11,7 @@ from test_numerical import test_fmoe_linear as _test_fmoe_linear
 from test_numerical import _test_fmoe_local_ddp


-def _run_distributed(func, world_size, args: Dict):
+def _run_distributed(func, world_size, args: Dict, script=__file__):
     if torch.cuda.device_count() < world_size:
         pytest.skip("No enough GPU")
     import subprocess

@@ -25,7 +25,7 @@ def _run_distributed(func, world_size, args: Dict):
     for i in range(world_size):
         os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
         p = subprocess.Popen(
-            [sys.executable, __file__, func, json.dumps(args)],
-            stdout=subprocess.PIPE
+            [sys.executable, script, func, json.dumps(args)],
+            stdout=subprocess.PIPE
         )
         ps.append(p)

@@ -41,8 +41,9 @@ def _run_distributed(func, world_size, args: Dict):
 @pytest.mark.parametrize("d_model", [16])
 @pytest.mark.parametrize("d_hidden", [32])
 @pytest.mark.parametrize("mp_size", [1, 2])
+@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
 def test_fmoe_linear_distributed(
-    num_expert, top_k, batch_size, d_model, d_hidden, mp_size
+    num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
 ):
     _run_distributed(
         "_test_fmoe_linear",

@@ -54,6 +55,7 @@ def test_fmoe_linear_distributed(
             "d_model": d_model,
             "d_hidden": d_hidden,
             "mp_size": mp_size,
+            "data_type": data_type
         },
     )

@@ -120,5 +122,6 @@ if __name__ == "__main__":
     else:
         test_fmoe_local_ddp(mp_size=2)
         test_fmoe_linear_distributed(
-            num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2
+            num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2,
+            data_type="torch.HalfTensor"
         )
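The new script argument lets sibling test files reuse the launcher: each worker is spawned as python <script> <func> <json-args>, and the target file's __main__ block dispatches to the named function. The pattern, with hypothetical names, looks like:

# in a hypothetical test_foo.py
import sys, json
from test_ddp import _run_distributed

def test_foo():
    _run_distributed('_test_foo', 1, {'d_model': 8}, script=__file__)

def _test_foo(d_model):
    pass  # body runs inside each spawned worker

if __name__ == '__main__':
    if len(sys.argv) >= 3:
        locals()[sys.argv[1]](**json.loads(sys.argv[2]))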
tests/test_gates.py

 import pytest
+import os
+import sys
+import json
+import math

 import torch
 import torch.distributed as dist

 from fmoe.gates import GShardGate, SwitchGate
+from test_ddp import _run_distributed


 def _ensure_initialized():

@@ -16,14 +21,27 @@ def _ensure_initialized():
         dist.init_process_group(backend="nccl")


-@pytest.mark.parametrize("d_model", [8, 1024])
-@pytest.mark.parametrize("batch_size", [16, 4096])
-@pytest.mark.parametrize("n_expert", [1, 4, 16])
-@pytest.mark.parametrize("cap", [.1, .5, 1.1])
+@pytest.mark.parametrize("d_model", [1024])
+@pytest.mark.parametrize("batch_size", [16])
+@pytest.mark.parametrize("n_expert", [1, 4])
+@pytest.mark.parametrize("cap", [.1, 1.1])
 def test_gshard_gate(d_model, batch_size, n_expert, cap):
-    _ensure_initialized()
-    if dist.get_world_size() * n_expert < 2:
+    if 1 * n_expert < 2:
         pytest.skip("No enough experts")
+    _run_distributed('_test_gshard_gate', 1,
+            {
+                'd_model': d_model,
+                'batch_size': batch_size,
+                'n_expert': n_expert,
+                'cap': cap
+            },
+            script=__file__)
+
+
+def _test_gshard_gate(d_model, batch_size, n_expert, cap):
+    _ensure_initialized()
     gate = GShardGate(d_model, n_expert, dist.get_world_size(),
             capacity=(cap, cap)).cuda()
     x = torch.rand(batch_size, d_model).cuda()

@@ -37,11 +55,24 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
             assert(i <= real_cap)


-@pytest.mark.parametrize("d_model", [8, 1024])
-@pytest.mark.parametrize("batch_size", [16, 4096])
-@pytest.mark.parametrize("n_expert", [1, 4, 16])
-@pytest.mark.parametrize("cap", [.1, .5, 1.1])
+@pytest.mark.parametrize("d_model", [1024])
+@pytest.mark.parametrize("batch_size", [4096])
+@pytest.mark.parametrize("n_expert", [1, 16])
+@pytest.mark.parametrize("cap", [.1, .8])
 def test_switch_gate(d_model, batch_size, n_expert, cap):
+    _run_distributed('_test_switch_gate', 1,
+            {
+                'd_model': d_model,
+                'batch_size': batch_size,
+                'n_expert': n_expert,
+                'cap': cap
+            },
+            script=__file__)
+
+
+def _test_switch_gate(d_model, batch_size, n_expert, cap):
     _ensure_initialized()
     gate = SwitchGate(d_model, n_expert, dist.get_world_size(),
             capacity=(cap, cap)).cuda()

@@ -57,6 +88,11 @@ def test_switch_gate(d_model, batch_size, n_expert, cap):
 if __name__ == '__main__':
-    _ensure_initialized()
-    test_gshard_gate(4096, 1024, 4, .2)
+    if len(sys.argv) >= 3:
+        args = json.loads(sys.argv[2])
+        locals()[sys.argv[1]](**args)
+    else:
+        _ensure_initialized()
+        # test_gshard_gate(4096, 1024, 4, .2)
+        test_gshard_gate(8, 16, 1, .1)
+        # test_switch_gate(4096, 1024, 4, .2)
tests/test_local_exchange.py

@@ -10,7 +10,7 @@ import numpy as np
 from copy import deepcopy

 from fmoe.functions import MOEGather, MOEScatter, count_by_gate
-from test_numerical import _assert_numercial
+from test_numerical import _assert_numerical


 @pytest.mark.parametrize("n_expert", [1, 4, 8])
 @pytest.mark.parametrize("topk", [1, 2])

@@ -30,10 +30,10 @@ def test_scatter(n_expert, topk, batch_size, d_model, world_size):
     inp_raw = inp.data.clone()
     out_raw = torch.empty(pos.shape[0], d_model, device=inp.device, dtype=inp.dtype)
-    out_raw.sum().backward()
+    # out_raw.sum().backward()
     for i, f in enumerate(pos.cpu()):
         out_raw[i] = inp[f % batch_size]
-    _assert_numercial(['out'], [out], [out_raw], 0)
+    _assert_numerical(['out'], [out], [out_raw], 0)
     # TODO: check grad

 if __name__ == '__main__':
tests/test_numerical.py

@@ -17,11 +17,13 @@ from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
 def _perform_forward(
-    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group
+    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group,
+    data_type='torch.FloatTensor'
 ):
     moe.zero_grad()
     moe_raw.zero_grad()
-    inp = torch.rand(batch_size, d_model).cuda()
+    inp = torch.rand(batch_size, d_model).type(data_type).cuda()
     if mp_group is not None:
         group_sender = rank // mp_group.size() * mp_group.size()
         torch.distributed.broadcast(inp, group_sender, group=mp_group)

@@ -46,15 +48,17 @@ def _perform_forward(
     return moe_out, raw_out, inp.grad, inp_raw.grad


-def _assert_numercial(names, moe_out_list, raw_out_list, rank):
+def _assert_numerical(names, moe_out_list, raw_out_list, rank,
+        precision=1e-3):
     for name, mo, ro in zip(names, moe_out_list, raw_out_list):
-        err = (mo - ro).abs().sum()
+        err = (mo - ro).abs().max()
         print("Rank {} {} abs err {}".format(rank, name, err))
-        if err > 1e-3:
+        if err > precision:
             sys.stderr.write(f"=========== {name} moe out ==============\n")
             sys.stderr.write("{}\n".format(mo))
             sys.stderr.write(f"=========== {name} raw out ==============\n")
             sys.stderr.write("{}\n".format(ro))
             sys.stderr.write(f"=========== {name} diff ==============\n")
             sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
             assert False
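Replacing .abs().sum() with .abs().max() makes the check size-independent: a summed error grows with element count even when every element is within tolerance, while the max judges each element on its own. For instance:

import torch

a = torch.rand(4096, 1024)
b = a + 1e-6                  # every element within a 1e-6 tolerance

print((a - b).abs().sum())    # ~4.2, would trip a 1e-3 threshold
print((a - b).abs().max())    # ~1e-6, passes as it should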
@@ -87,6 +91,7 @@ class MyMoE(FMoE):
 @pytest.mark.parametrize("mp_group", [None])
 @pytest.mark.parametrize("dp_group", [None])
 @pytest.mark.parametrize("world_group", [None])
+@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
 def test_fmoe_linear(
     num_expert,
     top_k,

@@ -98,6 +103,7 @@ def test_fmoe_linear(
     mp_group,
     dp_group,
     world_group,
+    data_type,
     activation=torch.nn.functional.gelu,
 ):
     torch.manual_seed(42 + rank)

@@ -105,7 +111,7 @@ def test_fmoe_linear(
     moe = MyMoE(
         num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
-    ).cuda()
+    ).type(data_type).cuda()
     moe_raw = BruteForceMoELinear(
         activation=activation,

@@ -114,7 +120,7 @@ def test_fmoe_linear(
         d_hidden=d_hidden,
         world_size=world_size,
         top_k=top_k,
-    ).cuda()
+    ).type(data_type).cuda()

     if world_size == 1:
         moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()

@@ -145,7 +151,7 @@ def test_fmoe_linear(
     moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

     moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
-        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
+        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type=data_type
     )
     moe_out_list = (

@@ -195,7 +201,10 @@ def test_fmoe_linear(
         "h4toh bias grad",
     ]

-    _assert_numercial(names, moe_out_list, raw_out_list, rank)
+    precision = 5e-1 if data_type == 'torch.HalfTensor' else 1e-3
+    _assert_numerical(names, moe_out_list, raw_out_list, rank,
+            precision=precision)


 @pytest.mark.parametrize("batch_size", [4])

@@ -296,7 +305,7 @@ def test_fmoe(
     raw_out_list = [raw_out, raw_grad, raw_grad_in]
     names = ["forward", "backward", "grad_in"]
-    _assert_numercial(names, moe_out_list, raw_out_list, rank)
+    _assert_numerical(names, moe_out_list, raw_out_list, rank)


 class MyModule(nn.Module):

@@ -372,7 +381,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
     names = ["mp grad", "dp grad", "wp grad"]
-    _assert_numercial(names, ddp_out_list, raw_out_list, rank)
+    _assert_numerical(names, ddp_out_list, raw_out_list, rank)

 if __name__ == "__main__":
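The relaxed 5e-1 bound for torch.HalfTensor reflects fp16's roughly three decimal digits of precision: after GEMMs over d_hidden terms plus a gelu, elementwise error can sit orders of magnitude above the 1e-3 used for fp32/fp64. A sketch of the gap (magnitudes indicative, assuming a CUDA device for the half matmul):

import torch

a = torch.rand(512, 512, device='cuda')
b = torch.rand(512, 512, device='cuda')

ref = a.double() @ b.double()
print(((a @ b).double() - ref).abs().max())                # fp32 error
print(((a.half() @ b.half()).double() - ref).abs().max())  # fp16: typically orders of magnitude larger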
tests/test_zero.py

 import sys
+import json

 import torch
 from fmoe.layers import _fmoe_general_global_forward
 from fmoe import FMoETransformerMLP
+from test_ddp import _run_distributed


 class ConstantGate(torch.nn.Module):
     def __init__(self, d_model, num_expert, world_size, top_k=1):

@@ -16,12 +19,34 @@ class ConstantGate(torch.nn.Module):
 def test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
+    _run_distributed('_test_zero_fwd', 1,
+            {
+                'num_expert': num_expert,
+                'batch_size': batch_size,
+                'd_hidden': d_hidden
+            },
+            script=__file__)
+
+
+def _test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
     inp = torch.rand(batch_size, d_hidden).cuda()
     gate = torch.zeros(batch_size, dtype=torch.int64).cuda()
     x = _fmoe_general_global_forward(inp, gate, lambda x, y: x, num_expert,
             world_size)


 def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
+    _run_distributed('_test_zero_transformer', 1,
+            {
+                'num_expert': num_expert,
+                'batch_size': batch_size,
+                'd_hidden': d_hidden
+            },
+            script=__file__)
+
+
+def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
     inp = torch.rand(batch_size, d_hidden).cuda()
     model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size,

@@ -30,9 +55,13 @@ def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
 if __name__ == '__main__':
-    # test_zero_fwd(world_size=torch.distributed.get_world_size())
-    test_zero_transformer(
-            num_expert=16, batch_size=4096, d_hidden=1024,
-            world_size=torch.distributed.get_world_size())
+    if len(sys.argv) >= 3:
+        args = json.loads(sys.argv[2])
+        torch.distributed.init_process_group(backend="nccl")
+        torch.cuda.set_device(torch.distributed.get_rank())
+        args['world_size'] = torch.distributed.get_world_size()
+        locals()[sys.argv[1]](**args)
+    else:
+        # test_zero_fwd(world_size=torch.distributed.get_world_size())
+        test_zero_transformer(num_expert=16, batch_size=4096, d_hidden=1024,
+                world_size=1)
     print('done')