OpenDAS / FastMoE / Commits
Commit 32e35812, authored Dec 29, 2020 by Jiezhong Qiu
Commit message: update
Parent: d4dd2a6c
Showing 2 changed files, with 33 additions and 26 deletions:
pytorch/cuda/moe.py (+24, -23)
pytorch/cuda/moe_cuda_kernel.cu (+9, -3)
pytorch/cuda/moe.py
...
@@ -10,18 +10,18 @@ torch.cuda.manual_seed(42)
 class MOEFunction(Function):
     @staticmethod
-    def forward(ctx, input, gate, weight):
-        output = moe_cuda.forward(input, gate, weight)
-        variables = [input, gate, weight]
+    def forward(ctx, inp, gate, weight):
+        output = moe_cuda.forward(inp, gate, weight)
+        variables = [inp, gate, weight]
         ctx.save_for_backward(*variables)
         return output[0]

     @staticmethod
     def backward(ctx, grad_out):
-        grad_input, grad_weight = moe_cuda.backward(
+        grad_inp, grad_weight = moe_cuda.backward(
             grad_out.contiguous(), *ctx.saved_tensors)
-        return grad_input, None, grad_weight
+        return grad_inp, None, grad_weight

 class MOELayer(nn.Module):
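Note: this hunk only renames input to inp (which also stops shadowing Python's built-in input()), but the code it touches is the standard custom torch.autograd.Function wrapper around a C++/CUDA extension. The sketch below restates that pattern in pure PyTorch; it assumes the extension computes the per-sample expert matmul that MOELayer_einsum implements further down this file, and RefMOEFunction is an illustrative stand-in rather than the extension's actual code.

import torch
from torch.autograd import Function

class RefMOEFunction(Function):
    # Pure-PyTorch stand-in for the moe_cuda-backed MOEFunction (sketch only).

    @staticmethod
    def forward(ctx, inp, gate, weight):
        # weight: [num_expert, out_feat, in_feat]; every sample uses the expert
        # picked by its integer gate entry.
        ctx.save_for_backward(inp, gate, weight)
        return torch.einsum('boi,bi->bo', weight[gate.long()], inp)

    @staticmethod
    def backward(ctx, grad_out):
        inp, gate, weight = ctx.saved_tensors
        w = weight[gate.long()]                                  # [B, out, in]
        grad_inp = torch.einsum('boi,bo->bi', w, grad_out)
        grad_weight = torch.zeros_like(weight)
        # Scatter each sample's outer product back onto its expert's weight.
        grad_weight.index_add_(0, gate.long(),
                               torch.einsum('bo,bi->boi', grad_out, inp))
        # gate is a non-differentiable index tensor, hence the None slot,
        # mirroring the (inp, gate, weight) argument order of forward().
        return grad_inp, None, grad_weight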
...
@@ -39,8 +39,8 @@ class MOELayer(nn.Module):
             linear = nn.Linear(in_features=self.in_feat, out_features=out_feat)
             self.weight.data[i] = linear.weight.data

-    def forward(self, input, gate):
-        return MOEFunction.apply(input, gate, self.weight)
+    def forward(self, inp, gate):
+        return MOEFunction.apply(inp, gate, self.weight)

 class MOELayer_einsum(nn.Module):
...
@@ -58,36 +58,37 @@ class MOELayer_einsum(nn.Module):
             linear = nn.Linear(in_features=self.in_feat, out_features=out_feat)
             self.weight.data[i] = linear.weight.data

-    def forward(self, input, gate):
+    def forward(self, inp, gate):
         gate_long = gate.long()
-        #W = self.weight[gate_long] # [batch_size x out_feat x in_feat]
-        #x = torch.einsum('id,ihd->ih', (input, W)) # [batch_size x out_feat]
-        #return x
-        batch_size = input.size(0)
-        x = input.new_zeros((batch_size, self.out_feat))
+        batch_size = inp.size(0)
+        x = inp.new_zeros((batch_size, self.out_feat))
         for i in range(batch_size):
-            x[i] = self.weight[gate_long[i]] @ input[i]
+            x[i] = self.weight[gate_long[i]] @ inp[i]
         return x

-batch_size = 2
-num_expert = 2
-in_feat = 2
-out_feat = 4
+batch_size = 1
+num_expert = 1
+in_feat = 3
+out_feat = 3
 moe = MOELayer(num_expert, in_feat, out_feat).cuda()
 moe_einsum = MOELayer_einsum(num_expert, in_feat, out_feat).cuda()
 moe_einsum.weight.data = moe.weight.data.clone()
-input = torch.rand(batch_size, in_feat).cuda()
+inp = torch.rand(batch_size, in_feat).cuda()
 gate = torch.randint(low=0, high=num_expert, size=(batch_size,), requires_grad=False).int().cuda()
-print(input)
+print(inp.type())
+print(moe.weight.data.type())
+print(inp)
 print(gate)
-output = moe(input, gate)
-print(input)
+output = moe(inp, gate)
+print(inp)
 print(gate)
-output_einsum = moe_einsum(input, gate)
+output_einsum = moe_einsum(inp.clone(), gate.clone())
 print(output)
 print(output_einsum)
...
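The script at the end of this file checks the CUDA layer against the loop/einsum reference by printing both outputs. As a hypothetical follow-up appended to that script (reusing its output and output_einsum variables), the comparison could be made numeric instead of visual:

print(torch.allclose(output, output_einsum, atol=1e-5))               # expect True
print("max abs error:", (output - output_einsum).abs().max().item())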
pytorch/cuda/moe_cuda_kernel.cu
...
@@ -161,9 +161,12 @@ void moe_cuda_forward_impl(
     checkCudaErrors(cudaMalloc(&Barray, batch_size*sizeof(const scalar_t*)));
     checkCudaErrors(cudaMalloc(&Carray, batch_size*sizeof(scalar_t*)));
+    int *gate_host = new int[batch_size];
+    checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size*sizeof(int), cudaMemcpyDeviceToHost));
     for (size_t i = 0; i < batch_size; ++i) {
         aptrs.push_back(input + in_feat * i);
-        bptrs.push_back(weight + out_feat * in_feat * i);
+        bptrs.push_back(weight + out_feat * in_feat * gate_host[i]);
         cptrs.push_back(output + out_feat * i);
     }
     checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size*sizeof(const scalar_t*), cudaMemcpyHostToDevice));
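The substantive fix in this hunk: the B pointer of each per-sample GEMM must address the weight matrix of the expert selected by that sample's gate, so the gate indices are first copied back to the host (gate_host) and the offset becomes out_feat * in_feat * gate_host[i] instead of out_feat * in_feat * i (the old indexing only matched the right expert when gate[i] happened to equal i). A PyTorch sketch of the same addressing, with illustrative shapes not taken from the kernel:

import torch

batch_size, num_expert, in_feat, out_feat = 4, 2, 3, 5
inp = torch.rand(batch_size, in_feat)
weight = torch.rand(num_expert, out_feat, in_feat)   # one [out_feat, in_feat] matrix per expert
gate = torch.randint(0, num_expert, (batch_size,))   # expert id per sample

# What weight + out_feat*in_feat*gate_host[i] points at, gathered for all i at once:
per_sample_w = weight[gate]                          # [batch_size, out_feat, in_feat]
out = torch.einsum('boi,bi->bo', per_sample_w, inp)  # one GEMV per sample
print(out.shape)                                     # torch.Size([4, 5])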
...
@@ -177,9 +180,12 @@ void moe_cuda_forward_impl(
     const scalar_t **B = (const scalar_t**)malloc(batch_size*sizeof(const scalar_t*));
     checkCudaErrors(cudaMemcpy(B, Barray, batch_size*sizeof(const scalar_t*), cudaMemcpyDeviceToHost));
-    std::cout << weight << std::endl;
+    std::cout << input << " " << weight << " " << output << std::endl;
     for (size_t i = 0; i < batch_size; ++i) {
-        std::cout << B[i] << " " << bptrs[i] << std::endl;
+        std::cout << i << std::endl;
+        std::cout << "A " << aptrs[i] << std::endl;
+        std::cout << "B " << B[i] << " " << bptrs[i] << std::endl;
+        std::cout << "C " << cptrs[i] << std::endl;
     }
     scalar_t alpha = 1, beta = 0;
...