Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
d4dd2a6c
Commit
d4dd2a6c
authored
Dec 29, 2020
by
Jiezhong Qiu
Browse files
debuging
parent
93291a7e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
11 deletions
+62
-11
pytorch/cuda/moe.py
pytorch/cuda/moe.py
+49
-8
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+13
-3
No files found.
pytorch/cuda/moe.py
View file @
d4dd2a6c
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
import
moe_cuda
import
moe_cuda
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed
(
42
)
class
MOEFunction
(
Function
):
class
MOEFunction
(
Function
):
@
staticmethod
@
staticmethod
...
@@ -27,29 +27,70 @@ class MOEFunction(Function):
...
@@ -27,29 +27,70 @@ class MOEFunction(Function):
class
MOELayer
(
nn
.
Module
):
class
MOELayer
(
nn
.
Module
):
def
__init__
(
self
,
num_expert
=
32
,
in_feat
=
1024
,
out_feat
=
4096
):
def
__init__
(
self
,
num_expert
=
32
,
in_feat
=
1024
,
out_feat
=
4096
):
super
(
MOELayer
,
self
).
__init__
()
super
(
MOELayer
,
self
).
__init__
()
self
.
num_expert
=
num_expert
self
.
in_feat
=
in_feat
self
.
out_feat
=
out_feat
self
.
weight
=
nn
.
Parameter
(
self
.
weight
=
nn
.
Parameter
(
torch
.
Tensor
(
num_expert
,
out_feat
,
in_feat
))
torch
.
Tensor
(
num_expert
,
out_feat
,
in_feat
))
self
.
reset_parameters
()
self
.
reset_parameters
()
def
reset_parameters
(
self
):
def
reset_parameters
(
self
):
pass
for
i
in
range
(
self
.
num_expert
):
linear
=
nn
.
Linear
(
in_features
=
self
.
in_feat
,
out_features
=
out_feat
)
self
.
weight
.
data
[
i
]
=
linear
.
weight
.
data
def
forward
(
self
,
input
,
gate
):
def
forward
(
self
,
input
,
gate
):
return
MOEFunction
.
apply
(
input
,
gate
,
self
.
weight
)
return
MOEFunction
.
apply
(
input
,
gate
,
self
.
weight
)
batch_size
=
64
class
MOELayer_einsum
(
nn
.
Module
):
num_expert
=
32
def
__init__
(
self
,
num_expert
=
32
,
in_feat
=
1024
,
out_feat
=
4096
):
in_feat
=
512
super
(
MOELayer_einsum
,
self
).
__init__
()
out_feat
=
512
self
.
num_expert
=
num_expert
self
.
in_feat
=
in_feat
self
.
out_feat
=
out_feat
self
.
weight
=
nn
.
Parameter
(
torch
.
Tensor
(
num_expert
,
out_feat
,
in_feat
))
self
.
reset_parameters
()
def
reset_parameters
(
self
):
for
i
in
range
(
self
.
num_expert
):
linear
=
nn
.
Linear
(
in_features
=
self
.
in_feat
,
out_features
=
out_feat
)
self
.
weight
.
data
[
i
]
=
linear
.
weight
.
data
def
forward
(
self
,
input
,
gate
):
gate_long
=
gate
.
long
()
#W = self.weight[gate_long] # [batch_size x out_feat x in_feat]
#x = torch.einsum('id,ihd->ih', (input, W)) # [batch_size x out_feat]
#return x
batch_size
=
input
.
size
(
0
)
x
=
input
.
new_zeros
((
batch_size
,
self
.
out_feat
))
for
i
in
range
(
batch_size
):
x
[
i
]
=
self
.
weight
[
gate_long
[
i
]]
@
input
[
i
]
return
x
batch_size
=
2
num_expert
=
2
in_feat
=
2
out_feat
=
4
moe
=
MOELayer
(
num_expert
,
in_feat
,
out_feat
).
cuda
()
moe
=
MOELayer
(
num_expert
,
in_feat
,
out_feat
).
cuda
()
moe_einsum
=
MOELayer_einsum
(
num_expert
,
in_feat
,
out_feat
).
cuda
()
moe_einsum
.
weight
.
data
=
moe
.
weight
.
data
.
clone
()
input
=
torch
.
rand
(
batch_size
,
in_feat
).
cuda
()
input
=
torch
.
rand
(
batch_size
,
in_feat
).
cuda
()
gate
=
torch
.
randint
(
low
=
0
,
high
=
num_expert
,
size
=
(
batch_size
,
),
requires_grad
=
False
).
int
().
cuda
()
gate
=
torch
.
randint
(
low
=
0
,
high
=
num_expert
,
size
=
(
batch_size
,
),
requires_grad
=
False
).
int
().
cuda
()
print
(
input
)
print
(
gate
)
output
=
moe
(
input
,
gate
)
output
=
moe
(
input
,
gate
)
print
(
input
)
print
(
gate
)
output_einsum
=
moe_einsum
(
input
,
gate
)
print
(
output
)
print
(
output_einsum
)
y
=
output
.
mean
()
#y = output.mean()
y
.
backward
()
#y.backward()
\ No newline at end of file
\ No newline at end of file
pytorch/cuda/moe_cuda_kernel.cu
View file @
d4dd2a6c
...
@@ -151,7 +151,7 @@ void moe_cuda_forward_impl(
...
@@ -151,7 +151,7 @@ void moe_cuda_forward_impl(
checkCudaErrors
(
cublasSetStream
(
h
->
handle
,
*
(
h
->
streams
)));
checkCudaErrors
(
cublasSetStream
(
h
->
handle
,
*
(
h
->
streams
)));
// setup Aarray, Barray and Carray
// setup Aarray, Barray and Carray
std
::
vector
<
const
scalar_t
*>
aptrs
;
std
::
vector
<
const
scalar_t
*>
aptrs
,
bptrs
;
std
::
vector
<
scalar_t
*>
cptrs
;
std
::
vector
<
scalar_t
*>
cptrs
;
const
scalar_t
**
Aarray
;
const
scalar_t
**
Aarray
;
...
@@ -163,6 +163,7 @@ void moe_cuda_forward_impl(
...
@@ -163,6 +163,7 @@ void moe_cuda_forward_impl(
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
aptrs
.
push_back
(
input
+
in_feat
*
i
);
aptrs
.
push_back
(
input
+
in_feat
*
i
);
bptrs
.
push_back
(
weight
+
out_feat
*
in_feat
*
i
);
cptrs
.
push_back
(
output
+
out_feat
*
i
);
cptrs
.
push_back
(
output
+
out_feat
*
i
);
}
}
checkCudaErrors
(
cudaMemcpy
(
Aarray
,
aptrs
.
data
(),
batch_size
*
sizeof
(
const
scalar_t
*
),
cudaMemcpyHostToDevice
));
checkCudaErrors
(
cudaMemcpy
(
Aarray
,
aptrs
.
data
(),
batch_size
*
sizeof
(
const
scalar_t
*
),
cudaMemcpyHostToDevice
));
...
@@ -173,14 +174,23 @@ void moe_cuda_forward_impl(
...
@@ -173,14 +174,23 @@ void moe_cuda_forward_impl(
dim3
blockdim
(
256
);
dim3
blockdim
(
256
);
generate_ptr_offset_kernel
<<<
griddim
,
blockdim
,
0
,
*
(
h
->
streams
)
>>>
(
batch_size
,
weight
,
out_feat
*
in_feat
,
gate
,
Barray
);
generate_ptr_offset_kernel
<<<
griddim
,
blockdim
,
0
,
*
(
h
->
streams
)
>>>
(
batch_size
,
weight
,
out_feat
*
in_feat
,
gate
,
Barray
);
const
scalar_t
**
B
=
(
const
scalar_t
**
)
malloc
(
batch_size
*
sizeof
(
const
scalar_t
*
));
checkCudaErrors
(
cudaMemcpy
(
B
,
Barray
,
batch_size
*
sizeof
(
const
scalar_t
*
),
cudaMemcpyDeviceToHost
));
std
::
cout
<<
weight
<<
std
::
endl
;
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
std
::
cout
<<
B
[
i
]
<<
" "
<<
bptrs
[
i
]
<<
std
::
endl
;
}
scalar_t
alpha
=
1
,
beta
=
0
;
scalar_t
alpha
=
1
,
beta
=
0
;
checkCudaErrors
(
cublasXgemmBatched
(
h
->
handle
,
checkCudaErrors
(
cublasXgemmBatched
(
h
->
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
transb
,
transb
,
1
,
out_feat
,
in_feat
,
1
,
out_feat
,
in_feat
,
&
alpha
,
&
alpha
,
Aarray
,
1
,
Aarray
,
1
,
Barray
,
out
_feat
,
Barray
,
(
transb
==
CUBLAS_OP_T
)
?
out_feat
:
in
_feat
,
&
beta
,
&
beta
,
Carray
,
1
,
Carray
,
1
,
batch_size
));
batch_size
));
...
@@ -234,7 +244,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
...
@@ -234,7 +244,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
const
auto
out_feat
=
weight
.
size
(
1
);
const
auto
out_feat
=
weight
.
size
(
1
);
const
auto
in_feat
=
weight
.
size
(
2
);
const
auto
in_feat
=
weight
.
size
(
2
);
//
printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld
, topk=%ld
\n", batch_size, num_expert, in_feat, out_feat
, top_k
);
printf
(
"b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld
\n
"
,
batch_size
,
num_expert
,
in_feat
,
out_feat
);
auto
output
=
input
.
new_zeros
({
batch_size
,
out_feat
});
auto
output
=
input
.
new_zeros
({
batch_size
,
out_feat
});
AT_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_forward_cuda"
,
([
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_forward_cuda"
,
([
&
]
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment