Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
0ab605cc
"docs/vscode:/vscode.git/clone" did not exist on "5545bbc54b5a2d5ea1e55e0426075afca1753ee8"
Commit
0ab605cc
authored
Dec 29, 2020
by
Jiezhong Qiu
Browse files
update
parent
b83ac1a5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
36 deletions
+12
-36
pytorch/cuda/moe.py
pytorch/cuda/moe.py
+10
-14
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+2
-22
No files found.
pytorch/cuda/moe.py
View file @
0ab605cc
...
...
@@ -11,8 +11,10 @@ torch.cuda.manual_seed(42)
class
MOEFunction
(
Function
):
@
staticmethod
def
forward
(
ctx
,
inp
,
gate
,
weight
):
output
=
moe_cuda
.
forward
(
inp
,
gate
,
weight
)
variables
=
[
inp
,
gate
,
weight
]
out_feat
,
in_feat
=
weight
.
size
()[
1
:]
weight_column_major
=
weight
.
transpose
(
-
1
,
-
2
).
contiguous
().
view
(
-
1
,
out_feat
,
in_feat
)
output
=
moe_cuda
.
forward
(
inp
,
gate
,
weight_column_major
)
variables
=
[
inp
,
gate
,
weight_column_major
]
ctx
.
save_for_backward
(
*
variables
)
return
output
[
0
]
...
...
@@ -21,7 +23,9 @@ class MOEFunction(Function):
def
backward
(
ctx
,
grad_out
):
grad_inp
,
grad_weight
=
moe_cuda
.
backward
(
grad_out
.
contiguous
(),
*
ctx
.
saved_tensors
)
return
grad_inp
,
None
,
grad_weight
out_feat
,
in_feat
=
grad_weight
.
size
()[
1
:]
grad_weight_row_major
=
grad_weight
.
transpose
(
-
1
,
-
2
).
contiguous
().
view
(
-
1
,
out_feat
,
in_feat
)
return
grad_inp
,
None
,
grad_weight_row_major
class
MOELayer
(
nn
.
Module
):
...
...
@@ -66,9 +70,9 @@ class MOELayer_einsum(nn.Module):
x
[
i
]
=
self
.
weight
[
gate_long
[
i
]]
@
inp
[
i
]
return
x
batch_size
=
1
num_expert
=
1
in_feat
=
3
batch_size
=
4
num_expert
=
4
in_feat
=
2
out_feat
=
3
moe
=
MOELayer
(
num_expert
,
in_feat
,
out_feat
).
cuda
()
...
...
@@ -79,15 +83,7 @@ moe_einsum.weight.data = moe.weight.data.clone()
inp
=
torch
.
rand
(
batch_size
,
in_feat
).
cuda
()
gate
=
torch
.
randint
(
low
=
0
,
high
=
num_expert
,
size
=
(
batch_size
,
),
requires_grad
=
False
).
int
().
cuda
()
print
(
inp
.
type
())
print
(
moe
.
weight
.
data
.
type
())
print
(
inp
)
print
(
gate
)
output
=
moe
(
inp
,
gate
)
print
(
inp
)
print
(
gate
)
output_einsum
=
moe_einsum
(
inp
.
clone
(),
gate
.
clone
())
print
(
output
)
...
...
pytorch/cuda/moe_cuda_kernel.cu
View file @
0ab605cc
...
...
@@ -140,18 +140,13 @@ void moe_cuda_forward_impl(
const
size_t
out_feat
,
const
size_t
num_expert
,
cublasOperation_t
transb
)
{
/*
cublasHandle_t handle;
cudaStream_t st;
checkCudaErrors(cudaStreamCreate(&st));
checkCudaErrors(cublasCreate(&handle));
*/
Helper
*
h
=
getHelper
(
num_expert
);
checkCudaErrors
(
cublasSetStream
(
h
->
handle
,
*
(
h
->
streams
)));
// setup Aarray, Barray and Carray
std
::
vector
<
const
scalar_t
*>
aptrs
,
bptrs
;
std
::
vector
<
const
scalar_t
*>
aptrs
;
std
::
vector
<
scalar_t
*>
cptrs
;
const
scalar_t
**
Aarray
;
...
...
@@ -161,12 +156,8 @@ void moe_cuda_forward_impl(
checkCudaErrors
(
cudaMalloc
(
&
Barray
,
batch_size
*
sizeof
(
const
scalar_t
*
)));
checkCudaErrors
(
cudaMalloc
(
&
Carray
,
batch_size
*
sizeof
(
scalar_t
*
)));
int
*
gate_host
=
new
int
[
batch_size
];
checkCudaErrors
(
cudaMemcpy
(
gate_host
,
gate
,
batch_size
*
sizeof
(
int
),
cudaMemcpyDeviceToHost
));
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
aptrs
.
push_back
(
input
+
in_feat
*
i
);
bptrs
.
push_back
(
weight
+
out_feat
*
in_feat
*
gate_host
[
i
]);
cptrs
.
push_back
(
output
+
out_feat
*
i
);
}
checkCudaErrors
(
cudaMemcpy
(
Aarray
,
aptrs
.
data
(),
batch_size
*
sizeof
(
const
scalar_t
*
),
cudaMemcpyHostToDevice
));
...
...
@@ -177,17 +168,6 @@ void moe_cuda_forward_impl(
dim3
blockdim
(
256
);
generate_ptr_offset_kernel
<<<
griddim
,
blockdim
,
0
,
*
(
h
->
streams
)
>>>
(
batch_size
,
weight
,
out_feat
*
in_feat
,
gate
,
Barray
);
const
scalar_t
**
B
=
(
const
scalar_t
**
)
malloc
(
batch_size
*
sizeof
(
const
scalar_t
*
));
checkCudaErrors
(
cudaMemcpy
(
B
,
Barray
,
batch_size
*
sizeof
(
const
scalar_t
*
),
cudaMemcpyDeviceToHost
));
std
::
cout
<<
input
<<
" "
<<
weight
<<
" "
<<
output
<<
std
::
endl
;
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
std
::
cout
<<
i
<<
std
::
endl
;
std
::
cout
<<
"A "
<<
aptrs
[
i
]
<<
std
::
endl
;
std
::
cout
<<
"B "
<<
B
[
i
]
<<
" "
<<
bptrs
[
i
]
<<
std
::
endl
;
std
::
cout
<<
"C "
<<
cptrs
[
i
]
<<
std
::
endl
;
}
scalar_t
alpha
=
1
,
beta
=
0
;
checkCudaErrors
(
cublasXgemmBatched
(
h
->
handle
,
CUBLAS_OP_N
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment