OpenDAS / FastMoE — Commit 307e0ad9

Authored Jan 10, 2021 by Rick Ho
Merge branch 'laekov/batching' into laekov/multigpu
Parents: cbd86de8 861b75c1

Showing 9 changed files with 512 additions and 370 deletions (+512 −370)
The merge replaces the fused two-layer MoE forward with a batched pipeline of four primitives — expert_count, local_scatter, per-expert GEMM forward, and local_gather — and keys the CUDA stream manager by device instead of by expert count.

pytorch/cuda/cuda_stream_manager.cpp   +55  −12
pytorch/cuda/cuda_stream_manager.h     +15  −29
pytorch/cuda/moe.cpp                   +56  −25
pytorch/cuda/moe.py                    +58  −48
pytorch/cuda/moe_cuda_kernel.cu        +227 −237
pytorch/cuda/moe_test.py               +24  −11
pytorch/cuda/run.sh                    +3   −3
pytorch/mem_transformer.py             +72  −5
pytorch/run_enwik8_base.sh             +2   −0
pytorch/cuda/cuda_stream_manager.cpp

The manager is no longer a single global created from num_expert; managers are now kept per CUDA device in a mutex-guarded map, and the old sync-everything behaviour becomes stream(), handle(), sync(idx), setup() and destroy(). New version of the file:

#include <cuda_runtime.h>
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include "cuda_stream_manager.h"
#include <helper_cuda.h>

#define SMGR_N_STREAMS 4

cudaStream_t CudaStreamManager::stream(size_t idx) {
    return this->streams[idx % SMGR_N_STREAMS];
}

cublasHandle_t CudaStreamManager::handle(size_t idx) {
    return this->handles[idx % SMGR_N_STREAMS];
}

void CudaStreamManager::sync(int idx) {
    for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
        cudaStreamSynchronize(streams[i]);
    }
}

void CudaStreamManager::setup(const int device) {
    checkCudaErrors(cudaSetDevice(device));
    streams = new cudaStream_t[SMGR_N_STREAMS];
    handles = new cublasHandle_t[SMGR_N_STREAMS];
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
        checkCudaErrors(cudaStreamCreate(streams + i));
        checkCudaErrors(cublasCreate(handles + i));
        cublasSetStream(handles[i], streams[i]);
    }
}

void CudaStreamManager::destroy() {
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
        checkCudaErrors(cudaStreamDestroy(streams[i]));
        checkCudaErrors(cublasDestroy(handles[i]));
    }
    delete[] streams;
    delete[] handles;
}

std::unordered_map<int, CudaStreamManager*> smgrs;
std::mutex smgr_mtx;

CudaStreamManager* getCudaStreamManager(const int device) {
    auto it = smgrs.find(device);
    if (it == smgrs.end()) {
        smgr_mtx.lock();
        it = smgrs.find(device);
        if (it == smgrs.end()) {
            auto smgr = new CudaStreamManager(device);
            smgrs.insert(std::pair<int, CudaStreamManager*>(device, smgr));
            smgr_mtx.unlock();
            return smgr;
        } else {
            smgr_mtx.unlock();
        }
    }
    return it->second;
}
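getCudaStreamManager resolves managers with double-checked locking: the map is probed once without the mutex, and only a miss pays for locking and a re-check. A minimal Python sketch of the same pattern (names hypothetical, an object() standing in for the manager):

import threading

_managers = {}
_lock = threading.Lock()

def get_manager(device):
    # Fast path: no lock taken when the manager already exists.
    mgr = _managers.get(device)
    if mgr is None:
        with _lock:
            # Re-check under the lock: another thread may have created it
            # between the first lookup and acquiring the lock.
            mgr = _managers.get(device)
            if mgr is None:
                mgr = object()  # stand-in for a per-device CudaStreamManager
                _managers[device] = mgr
    return mgr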
pytorch/cuda/cuda_stream_manager.h

The header slims down accordingly: the inline constructor/destructor that allocated MAX_STREAMS = 16 streams and handles is gone, num_expert is replaced by the owning device, and getStream/getHandle/sync(int = -1) give way to declarations of setup, sync, destroy, stream and handle. New version:

...
@@ -3,44 +3,30 @@
#include <cuda_runtime.h>
#include <cublas_v2.h>

class CudaStreamManager {
public:
    int device;
    cublasHandle_t* handles;
    cudaStream_t* streams;

    CudaStreamManager(int device_) : device(device_) {
        this->setup(device);
    }
    void setup(int);
    void sync(int = 0);
    void destroy();

    cudaStream_t stream(size_t = 0);
    cublasHandle_t handle(size_t = 0);

    ~CudaStreamManager() {
        this->destroy();
    }
};

CudaStreamManager* getCudaStreamManager(const int device);

#endif // CUDA_STREAM_MANAGER
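The indexing contract here is worth spelling out: stream(idx) and handle(idx) wrap around modulo the pool size, and sync(n) waits on the first min(n, pool) streams, which is why callers pass the number of work items they launched. A small sketch of that arithmetic (pool size assumed to be the SMGR_N_STREAMS = 4 from the .cpp):

SMGR_N_STREAMS = 4

def stream_of(idx):
    # Same wrap-around as CudaStreamManager::stream / ::handle.
    return idx % SMGR_N_STREAMS

def streams_to_sync(n):
    # sync(n) loops while i < n and i < SMGR_N_STREAMS.
    return list(range(min(n, SMGR_N_STREAMS)))

assert [stream_of(i) for i in range(6)] == [0, 1, 2, 3, 0, 1]
assert streams_to_sync(8) == [0, 1, 2, 3]  # e.g. sync(num_expert) with 8 experts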
pytorch/cuda/moe.cpp

The C++ binding layer drops the old forward(input, gate, weight1, weight2) / backward(grad_output, input, gate, weight) pair and instead exposes the four batched primitives plus a backward that works on pre-scattered buffers. New version:

...
@@ -4,17 +4,27 @@
#include <iostream>
#include <vector>

std::vector<torch::Tensor> moe_cuda_expert_count(
        torch::Tensor gate, size_t num_expert);
std::vector<torch::Tensor> moe_cuda_local_scatter(
        torch::Tensor input,
        torch::Tensor pos);
std::vector<torch::Tensor> moe_cuda_local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos);
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count);
std::vector<torch::Tensor> moe_cuda_backward(
        torch::Tensor grad_output_buf,
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count);

// C++ interface
...
@@ -23,40 +33,58 @@ std::vector<torch::Tensor> moe_cuda_backward(
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> moe_expert_count(torch::Tensor gate, size_t num_expert) {
    CHECK_INPUT(gate);
    return moe_cuda_expert_count(gate, num_expert);
}

std::vector<torch::Tensor> moe_local_scatter(torch::Tensor input, torch::Tensor pos) {
    CHECK_INPUT(input);
    return moe_cuda_local_scatter(input, pos);
}

std::vector<torch::Tensor> moe_local_gather(torch::Tensor output_buf, torch::Tensor pos) {
    CHECK_INPUT(output_buf);
    return moe_cuda_local_gather(output_buf, pos);
}

std::vector<torch::Tensor> moe_forward(
        torch::Tensor input_buf,    // [batch_size x in_feat]
        torch::Tensor weight,       // [num_expert x out_feat x in_feat]
        torch::Tensor expert_count  // [num_expert]
        ) {
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    /*
    The bias term should have been merged into weight. Note the following fact that
    Wx+b = [W b] [x]
                 [1]
    */
    return moe_cuda_forward(input_buf, weight, expert_count);
}

std::vector<torch::Tensor> moe_backward(
        torch::Tensor grad_output_buf,  // [batch_size x out_feat]
        torch::Tensor input_buf,        // [batch_size x in_feat]
        torch::Tensor weight,           // [num_expert x out_feat x in_feat]
        torch::Tensor expert_count
        ) {
    CHECK_INPUT(grad_output_buf);
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
}

...
@@ -72,6 +100,9 @@ int main() {
*/

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
    m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
    m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
    m.def("forward", &moe_forward, "MoE forward (CUDA)");
    m.def("backward", &moe_backward, "MoE backward (CUDA)");
}
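The Wx+b comment is why MOELayer carries no bias parameter: a bias can always be folded into the weight by augmenting the input with a constant 1. A quick NumPy check of the identity (shapes chosen arbitrarily):

import numpy as np

W = np.random.rand(3, 4)
b = np.random.rand(3)
x = np.random.rand(4)

# Append b as an extra column of W and a constant 1 to x: one matmul
# then computes Wx + b.
W_aug = np.hstack([W, b[:, None]])   # 3 x 5
x_aug = np.append(x, 1.0)            # 5

assert np.allclose(W_aug @ x_aug, W @ x + b)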
pytorch/cuda/moe.py

The autograd function now drives the four-step pipeline (expert_count → local_scatter → forward → local_gather) over a single weight tensor per expert; the weight1/weight2 two-layer variant and the column-major gradient reshuffle are removed, and a test_dp() data-parallel smoke test is added. New version:

...
@@ -5,86 +5,75 @@ import torch
import moe_cuda

class MOEFunction(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight):
        # out_feat, in_feat = weight.size()[1:]
        # weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
        expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
        input_buf, = moe_cuda.local_scatter(inp, pos)
        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
        output = moe_cuda.local_gather(output_buf, pos)

        variables = [input_buf, gate, weight, expert_count, pos]
        ctx.save_for_backward(*variables)

        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors
        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = moe_cuda.backward(
                grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)

        return grad_inp, None, grad_weight

class MOELayer(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(MOELayer, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        return MOEFunction.apply(inp, gate.int(), self.weight)

class MOELayer_raw(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(MOELayer_raw, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            # print(linear.weight.shape)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        for i in range(batch_size):
            x[i] = inp[i] @ self.weight[gate_long[i]].t()
        return x

...
@@ -93,28 +82,24 @@ def test_module(moe, linear, inp, gate):
    moe.zero_grad()
    x = linear(inp)
    output = moe(x, gate)
    if torch.distributed.get_rank() == 1:
        print(output)
    return output
    y = output.mean()
    y.backward()
    return output, moe.weight.grad, linear.weight.grad, linear.bias.grad

def test():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7
    linear = nn.Linear(in_feat, in_feat).cuda()
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
    moe_raw.weight.data = moe.weight.data.clone()

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size,),
            requires_grad=False).int().cuda()
...
@@ -124,11 +109,36 @@ def test():
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
    names = ['Out']
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('{} abs err {}'.format(name, err))

def test_dp():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    batch_size = 6
    num_expert = 4
    in_feat = 2
    out_feat = 3
    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size,),
            requires_grad=False).int().cuda()

    print("data parallel of a nn.Linear model")
    linear = nn.Linear(in_feat, in_feat).cuda()
    linear_dp = torch.nn.DataParallel(linear, device_ids=[0, 1, 2])
    output = linear_dp(inp)
    print("successful!")

    print("data parallel of our MoE model")
    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
    moe_dp = torch.nn.DataParallel(moe, device_ids=[0, 1, 2])
    for i in range(5):
        output = moe_dp(inp, gate)

if __name__ == '__main__':
    torch.distributed.init_process_group(backend='mpi')
    test()
    # test_dp()
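For readers following the new pipeline, here is a pure-PyTorch sketch of what expert_count and the pos permutation mean (function name hypothetical; the CUDA extension computes this on the host from the copied gate array):

import torch

def sorted_positions(gate, num_expert):
    # How many rows go to each expert, and where each row lands once the
    # batch is regrouped so rows of the same expert are contiguous.
    expert_count = torch.bincount(gate, minlength=num_expert)
    offsets = torch.cumsum(expert_count, 0) - expert_count
    pos = torch.empty_like(gate)
    fill = offsets.clone()
    for i, g in enumerate(gate.tolist()):
        pos[i] = fill[g]
        fill[g] += 1
    return expert_count, pos

gate = torch.tensor([2, 0, 2, 1])
count, pos = sorted_positions(gate, 3)
# count == [1, 1, 2] and pos == [2, 0, 3, 1]: local_scatter copies row i to
# slot pos[i], forward runs one GEMM per contiguous expert slice, and
# local_gather inverts the permutation.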
pytorch/cuda/moe_cuda_kernel.cu

The monolithic moe_cuda_forward_impl — which copied the gate to the host, exchanged per-expert counts over MPI_Alltoall, allocated hidden buffers for the two-layer weight1/weight2 GEMMs, and printed MOE_BREAKDOWN timings — is split into single-purpose pieces: a host-side expert counter, local scatter/gather kernels, and one GEMM loop per direction over a single weight tensor. The MPI/NCCL cross-worker exchange is parked in moe_cuda_global_scatter() / moe_cuda_global_gather() (bodies elided in this view), the moe_cuda_grad_weight and print_first_float helpers are removed, and a leftover "/* TODO: Backward currently broken */" comment marks the backward dispatch as work in progress. New version:

...
@@ -4,11 +4,11 @@
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include <mpi.h>
...
@@ -20,10 +20,6 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
// #define MOE_DEBUG

template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
...
@@ -34,10 +30,9 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
    }
}

template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const int* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    inbuf += wid * blockIdx.x;
    oubuf += wid * pos[blockIdx.x];
...
@@ -46,55 +41,15 @@ void batch_scatter_kernel(int wid, int* pos,
    }
}

void moe_cuda_expert_count_impl(
        const int* d_gate,
        int* expert_count,
        int* d_pos,
        const size_t num_expert,
        const size_t batch_size) {
    int* gate = new int[batch_size];
    int* expert_ptr = new int[num_expert];
    memset(expert_count, 0, sizeof(int) * num_expert);
    checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
            cudaMemcpyDeviceToHost));
...
@@ -108,8 +63,6 @@ void moe_cuda_forward_impl(
    }
    int* pos = new int[batch_size];
    for (int i = 0; i < batch_size; ++i) {
        pos[i] = expert_ptr[gate[i]]++;
...
@@ -120,40 +73,11 @@ void moe_cuda_forward_impl(
    expert_ptr[0] = 0;
    checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
            cudaMemcpyHostToDevice));
    delete [] gate;
    delete [] expert_ptr;
}

...
@@ -192,58 +116,137 @@ void moe_cuda_forward_impl(
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
        const int* d_pos,
        scalar_t* input_buf,
        const size_t batch_size,
        const size_t in_feat,
        CudaStreamManager* smgr) {
    batch_scatter_kernel<scalar_t>
        <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input, input_buf);
    smgr->sync(1);
}

template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const int* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    inbuf += wid * pos[blockIdx.x];
    oubuf += wid * blockIdx.x;
    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}

template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
        const int* d_pos,
        scalar_t* output,
        const size_t batch_size,
        const size_t out_feat,
        CudaStreamManager* smgr) {
    batch_gather_kernel<scalar_t>
        <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf, output);
    smgr->sync(1);
}

template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
        const int* expert_count,
        scalar_t* output_buf,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C
        checkCudaErrors(cublasXgemm(smgr->handle(i),
                CUBLAS_OP_T,
                CUBLAS_OP_N,
                out_feat, expert_count[i], in_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                input_buf + ptr * in_feat, in_feat,
                &beta,
                output_buf + out_feat * ptr, out_feat
                ));
        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}

template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
        const scalar_t* weight,
        const int* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                    sizeof(scalar_t) * in_feat * out_feat);
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C
        // Backward input: g_i = w @ g_o
        checkCudaErrors(cublasXgemm(smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_N,
                in_feat, expert_count[i], out_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_input_buf + in_feat * ptr, in_feat
                ));
        // Backward weight: g_w = i @ g_o
        checkCudaErrors(cublasXgemm(smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_T,
                in_feat, out_feat, expert_count[i],
                &alpha,
                input_buf + in_feat * ptr, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_weight + i * in_feat * out_feat, in_feat
                ));
        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}

...
@@ -273,108 +276,105 @@ void moe_cuda_forward_impl(
std::vector<torch::Tensor> moe_cuda_expert_count(
        torch::Tensor gate,
        size_t num_expert) {
    const auto batch_size = gate.size(0);

    auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
    auto expert_count = torch::empty(num_expert, ec_options);

    auto pos_options = torch::TensorOptions()
        .device(gate.device())
        .dtype(torch::kInt32);
    auto pos = torch::empty(batch_size, pos_options);

    moe_cuda_expert_count_impl(
            gate.data_ptr<int>(),
            expert_count.data_ptr<int>(),
            pos.data_ptr<int>(),
            num_expert,
            batch_size);

    return {expert_count, pos};
}

std::vector<torch::Tensor> moe_cuda_local_scatter(
        torch::Tensor input,
        torch::Tensor pos) {
    auto smgr = getCudaStreamManager(input.device().index());
    const auto batch_size = input.size(0);
    const auto in_feat = input.size(1);

    auto input_buf = torch::empty_like(input);

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda",
            ([&] {
        moe_cuda_local_scatter_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            pos.data_ptr<int>(),
            input_buf.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            smgr);
    }));
    return {input_buf,};
}

std::vector<torch::Tensor> moe_cuda_local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos) {
    auto smgr = getCudaStreamManager(output_buf.device().index());
    const auto batch_size = output_buf.size(0);
    const auto out_feat = output_buf.size(1);

    auto output = torch::empty_like(output_buf);

    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda",
            ([&] {
        moe_cuda_local_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            pos.data_ptr<int>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            smgr);
    }));
    return {output,};
}

std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
            num_expert, in_feat, out_feat);
#endif
    auto out_options = torch::TensorOptions()
        .device(input_buf.device())
        .dtype(input_buf.dtype());
    auto output = torch::empty({batch_size, out_feat}, out_options);

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda",
            ([&] {
        moe_cuda_forward_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<int>(),
            output.data_ptr<scalar_t>(),
            in_feat,
            out_feat,
            num_expert,
            smgr);
    }));
...
@@ -382,53 +382,43 @@ std::vector<torch::Tensor> moe_cuda_forward(
}

std::vector<torch::Tensor> moe_cuda_backward(
        torch::Tensor grad_output_buf,  // [batch_size x out_feat]
        torch::Tensor input_buf,        // [batch_size x in_feat]
        torch::Tensor weight,           // [num_expert x out_feat x in_feat]
        torch::Tensor expert_count) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
            "out_feat (d_ffn)=%ld\n",
            batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<int>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            smgr);
    }));

    return {grad_input_buf, grad_weight};
}
...
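The "T(B) x T(A) = T(C)" comment encodes the usual row-major/column-major trick: cuBLAS sees a row-major [m x k] buffer as the column-major transpose, so requesting op=T on the weight and op=N on the input makes it write C^T, which in memory is exactly the row-major input @ weight^T the kernel wants. A NumPy illustration of the intended result (sizes arbitrary):

import numpy as np

input_buf = np.random.rand(5, 4)    # rows routed to one expert, in_feat = 4
weight = np.random.rand(3, 4)       # one expert's weight, out_feat x in_feat

# What the cublasXgemm call is arranged to leave in output_buf:
output_buf = input_buf @ weight.T   # row-major, rows x out_feat

assert output_buf.shape == (5, 3)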
pytorch/cuda/moe_test.py

The benchmark now takes in_feat/out_feat instead of io_feat/hidden_feat, pins tensors to dev_name = 'cuda:0', and times backward separately. New version:

...
@@ -4,48 +4,61 @@ import time
import sys

dev_name = 'cuda:0'

def perf():
    torch.manual_seed(42 + torch.distributed.get_rank())
    torch.cuda.manual_seed(42 + torch.distributed.get_rank())

    batch_size = int(sys.argv[1])
    in_feat = int(sys.argv[2])
    out_feat = int(sys.argv[3])
    num_expert = int(sys.argv[4])

    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
    gate = torch.randint(low=0,
            high=num_expert * torch.distributed.get_world_size(),
            size=(batch_size,), requires_grad=False).int().cuda(dev_name)
    moe = MOELayer(num_expert, in_feat, out_feat).cuda(dev_name)
    moe.train()

    o = moe(inp, gate)
    o = moe(inp, gate)
    o = moe(inp, gate)
    o = moe(inp, gate)
    o = moe(inp, gate)

    n_runs = 16
    tott = 0.
    backt = 0.
    maxt = 0.
    sqtot = 0.
    for i in range(n_runs):
        gate = torch.randint(low=0,
                high=num_expert * torch.distributed.get_world_size(),
                size=(batch_size,), requires_grad=False).int().cuda(dev_name)
        ts = time.time()
        o = moe(inp, gate)
        te = time.time()

        loss = o.sum()
        bts = time.time()
        loss.backward()
        bte = time.time()

        tott += te - ts
        sqtot += (te - ts) ** 2
        maxt = max(maxt, te - ts)
        backt += bte - bts

    gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
    print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
        tott * 1e3 / n_runs, maxt * 1e3,
        (sqtot / n_runs - (tott / n_runs) ** 2) * 1e3 / n_runs,
        backt * 1e3 / n_runs, gflops))

if __name__ == '__main__':
...
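The GFLOPs line counts 2 * batch_size * in_feat * out_feat floating-point operations per forward pass (one multiply and one add per weight element touched). A worked instance of the formula with made-up timings:

batch_size, in_feat, out_feat, n_runs = 4096, 1024, 4096, 16
tott = 2.0  # seconds over all timed runs, invented for the example

gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
print(round(gflops, 1))  # ~274.9 sustained GFLOPs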
pytorch/cuda/run.sh

All invocations switch from python to python3:

...
@@ -8,7 +8,7 @@ export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH
if [ -z $1 ]
then
    python3 moe.py
elif [ .$1 = '.test_all' ]
then
    for nexp in 1 2 4
...
@@ -20,11 +20,11 @@ then
            for bs in 4 16 64 256 512 1024 2048 4096
            do
                echo $bs $nexp ${inf}x${ouf}
                python3 moe_test.py $bs $inf $ouf $nexp
            done
        done
    done
done
else
    python3 $@ 2>logs/$OMPI_COMM_WORLD_RANK.log
fi
pytorch/mem_transformer.py

The transformer now imports the CUDA MOELayer and gains CustomizedMoEPositionwiseFF, a top-k gated position-wise FFN built from two MOELayers (inputs are padded with a constant 1.0 so each expert's bias rides inside its weight matrix); the previous MoEPositionwiseFF is renamed MoEPositionwiseFFRaw, and all three decoder-layer variants switch their pos_ff from MultiHeadHierarchicalMoEPositionwiseFF to CustomizedMoEPositionwiseFF.

...
@@ -9,6 +9,8 @@ import torch.nn as nn
import torch.nn.functional as F
# import torch_sparse

from cuda.moe import MOELayer

sys.path.append('utils')
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
from log_uniform_sampler import LogUniformSampler, sample_logits
...
@@ -31,9 +33,74 @@ class PositionalEmbedding(nn.Module):
        else:
            return pos_emb[:,None,:]

class CustomizedMoEPositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False,
            top_k=2, num_expert=32):
        super(CustomizedMoEPositionwiseFF, self).__init__()
        print("CustomizedMoEPositionwiseFF num_expert=%d top_k=%d"
                % (num_expert, top_k))
        self.top_k = top_k
        assert num_expert >= top_k
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.gate = nn.Linear(d_model, num_expert)
        self.moe1 = MOELayer(num_expert=num_expert, in_feat=d_model + 1,
                out_feat=d_inner)
        self.moe2 = MOELayer(num_expert=num_expert, in_feat=d_inner + 1,
                out_feat=d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm
        self.dropout = nn.Dropout(dropout)
        self.reset_parameter()

    def reset_parameter(self):
        pass

    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k,
                dim=-1, largest=True, sorted=False)  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)  # (BxL) x 1 x top_k
        gate_top_k_idx = gate_top_k_idx.view(-1, self.top_k)

        core_out = []
        inp = inp.view(-1, self.d_model)
        inp = F.pad(inp, pad=(0, 1), mode='constant', value=1.0)
        for i in range(self.top_k):
            gate_idx = gate_top_k_idx[:, i].contiguous()
            x = self.moe1(inp, gate_idx)
            x = self.dropout(F.relu(x))
            x = F.pad(x, pad=(0, 1), mode='constant', value=1.0)
            x = self.moe2(x, gate_idx)
            x = self.dropout(x)  # (BxL) x d_model
            core_out.append(x.unsqueeze(1))  # (BxL) x 1 x d_model

        core_out = torch.cat(core_out, dim=1)  # (BxL) x top_k x d_model
        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
        core_out = core_out.view(residual.size(0), residual.size(1),
                self.d_model)

        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output

class MoEPositionwiseFFRaw(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
        super(MoEPositionwiseFFRaw, self).__init__()
        print("MoEPositionwiseFF")
        self.top_k = top_k
...
@@ -820,7 +887,7 @@ class DecoderLayer(nn.Module):
        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None):
...
@@ -840,7 +907,7 @@ class RelLearnableDecoderLayer(nn.Module):
        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                **kwargs)
        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
...
@@ -861,7 +928,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                d_head, dropout, **kwargs)
        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
...
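The mixing step at the end of CustomizedMoEPositionwiseFF.forward is a batched matmul between the softmaxed top-k gate scores and the stacked expert outputs. A standalone sketch of just that step, with made-up sizes:

import torch
import torch.nn.functional as F

scores = torch.randn(10, 32)               # (BxL) x num_expert gate logits
top_val, top_idx = torch.topk(scores, k=2, dim=-1, sorted=False)
gate_score = F.softmax(top_val, dim=-1).unsqueeze(1)   # (BxL) x 1 x top_k

# top_idx picks which expert handles each of the k passes; their outputs
# are stacked and blended by the normalized scores.
expert_out = torch.randn(10, 2, 512)        # (BxL) x top_k x d_model
mixed = torch.bmm(gate_score, expert_out).squeeze(1)   # (BxL) x d_model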
pytorch/run_enwik8_base.sh

An LD_LIBRARY_PATH export is added before the training launch:

#!/bin/bash

export LD_LIBRARY_PATH=/home/jiezhong/miniconda3/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
...