OpenDAS / FastMoE · Commits

Commit eb47044a, authored Dec 23, 2020 by Jiezhong Qiu

topk=1

Parent: 3a458fa7
Showing 2 changed files with 12 additions and 18 deletions (+12 -18):

    pytorch/cuda/moe.cpp             +1  -1
    pytorch/cuda/moe_cuda_kernel.cu  +11 -17
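Taken together, the two diffs below specialize moe1_cuda_forward to top_k = 1: the gate shrinks from [B x K] expert indices to one index per row, the A/B/C pointer arrays for the batched GEMM hold batch_size entries instead of batch_size * top_k, and the output loses its top-k dimension. As a concrete illustration of the pattern the new code follows, here is a self-contained sketch; it is my reconstruction rather than the repository's code, it builds Barray on the host instead of with a kernel, and it uses cublasSgemmBatched where the repository calls its cublasXgemmBatched wrapper:

    // Sketch of the top_k = 1 batched-GEMM setup (illustrative, not repository code).
    // Error checking and cudaFree cleanup are omitted for brevity.
    #include <cstdio>
    #include <vector>
    #include <cuda_runtime.h>
    #include <cublas_v2.h>

    int main() {
        const int batch_size = 4, in_feat = 8, out_feat = 16, num_expert = 2;

        float *input, *weight, *output;
        cudaMalloc(&input,  batch_size * in_feat * sizeof(float));
        cudaMalloc(&weight, num_expert * out_feat * in_feat * sizeof(float));
        cudaMalloc(&output, batch_size * out_feat * sizeof(float));
        cudaMemset(input,  0, batch_size * in_feat * sizeof(float));
        cudaMemset(weight, 0, num_expert * out_feat * in_feat * sizeof(float));

        // With top_k = 1 the gate is one expert index per sample ([B], not [B x K]).
        const std::vector<int> gate = {0, 1, 1, 0};

        // One pointer triple per sample; bptrs selects the gated expert's weight block.
        std::vector<const float*> aptrs, bptrs;
        std::vector<float*> cptrs;
        for (int i = 0; i < batch_size; ++i) {
            aptrs.push_back(input + in_feat * i);
            bptrs.push_back(weight + out_feat * in_feat * gate[i]);
            cptrs.push_back(output + out_feat * i);
        }

        // Pointer arrays hold batch_size entries, no * top_k factor.
        const float **Aarray, **Barray;
        float **Carray;
        cudaMalloc(&Aarray, batch_size * sizeof(const float*));
        cudaMalloc(&Barray, batch_size * sizeof(const float*));
        cudaMalloc(&Carray, batch_size * sizeof(float*));
        cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const float*), cudaMemcpyHostToDevice);
        cudaMemcpy(Barray, bptrs.data(), batch_size * sizeof(const float*), cudaMemcpyHostToDevice);
        cudaMemcpy(Carray, cptrs.data(), batch_size * sizeof(float*), cudaMemcpyHostToDevice);

        // output_i = W_{gate[i]} * input_i, one GEMV-sized GEMM per sample.
        // Weights are row-major [out_feat x in_feat], hence CUBLAS_OP_T.
        cublasHandle_t handle;
        cublasCreate(&handle);
        const float alpha = 1.f, beta = 0.f;
        cublasSgemmBatched(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                           out_feat, 1, in_feat, &alpha,
                           Barray, in_feat,   // op(B_i) is [out_feat x in_feat]
                           Aarray, in_feat,   // x_i is [in_feat x 1]
                           &beta,
                           Carray, out_feat,  // y_i is [out_feat x 1]
                           batch_size);
        cudaDeviceSynchronize();
        printf("ran %d single-sample GEMMs\n", batch_size);
        cublasDestroy(handle);
        return 0;
    }

Each sample contributes exactly one [out_feat x in_feat] x [in_feat x 1] product, so the batch count passed to cuBLAS is simply batch_size.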
pytorch/cuda/moe.cpp
@@ -18,7 +18,7 @@ std::vector<torch::Tensor> moe1_cuda_forward(
 std::vector<torch::Tensor> moe1_forward(
         torch::Tensor input,   // [B x D_model]
-        torch::Tensor gate,    // [B x K]
+        torch::Tensor gate,    // [B]
         torch::Tensor weight   // [N x D_ffn x D_model]
         ) {
     CHECK_INPUT(input);
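The only functional change in this file is the gate annotation: with top_k fixed at 1 the binding expects one expert index per row ([B]) rather than K indices ([B x K]). A hypothetical helper (the name and the long gate dtype are my assumptions, not from this file) makes the resulting weight lookup explicit:

    #include <cstddef>

    // Hypothetical helper, not in this commit: with gate shaped [B], row i's
    // expert weight block is a single contiguous [D_ffn x D_model] slice.
    template <typename scalar_t>
    const scalar_t* expert_weight(const scalar_t* weight,   // [N x D_ffn x D_model]
                                  const long* gate,         // [B] expert index per row
                                  std::size_t out_feat, std::size_t in_feat,
                                  std::size_t i) {
        return weight + out_feat * in_feat * gate[i];
    }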
pytorch/cuda/moe_cuda_kernel.cu
@@ -74,7 +74,6 @@ void moe1_cuda_forward_impl(
         const scalar_t* weight,
         scalar_t* output,
         const size_t batch_size,
-        const size_t top_k,
         const size_t in_feat,
         const size_t out_feat) {
@@ -91,24 +90,21 @@ void moe1_cuda_forward_impl(
     const scalar_t** Aarray;
     const scalar_t** Barray;
     scalar_t** Carray;
-    checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*) * top_k));
-    checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*) * top_k));
-    checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*) * top_k));
+    checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*)));
+    checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*)));
+    checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*)));
     for (size_t i = 0; i < batch_size; ++i) {
-        for (size_t k = 0; k < top_k; ++k) {
-            aptrs.push_back(input + in_feat * i);
-            // bptrs.push_back(weight + out_feat * in_feat * gate[i * top_k + k]);
-            cptrs.push_back(output + out_feat * (i * top_k + k));
-        }
+        aptrs.push_back(input + in_feat * i);
+        cptrs.push_back(output + out_feat * i);
     }
-    checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const scalar_t*) * top_k, cudaMemcpyHostToDevice));
+    checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const scalar_t*), cudaMemcpyHostToDevice));
     // checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k, bptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
-    checkCudaErrors(cudaMemcpy(Carray, cptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
+    checkCudaErrors(cudaMemcpy(Carray, cptrs.data(), batch_size * sizeof(scalar_t*), cudaMemcpyHostToDevice));
-    dim3 griddim(CEIL(batch_size * top_k, 256));
+    dim3 griddim(CEIL(batch_size, 256));
     dim3 blockdim(256);
-    generate_ptr_offset_kernel<<<griddim, blockdim, 0, st>>>(batch_size * top_k, weight, out_feat * in_feat, gate, Barray);
+    generate_ptr_offset_kernel<<<griddim, blockdim, 0, st>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
     scalar_t alpha = 1, beta = 0;
     checkCudaErrors(cublasXgemmBatched(handle,
@@ -120,7 +116,7 @@ void moe1_cuda_forward_impl(
             Barray, out_feat,
             &beta,
             Carray, 1,
-            batch_size * top_k));
+            batch_size));
     checkCudaErrors(cudaStreamSynchronize(st));
     checkCudaErrors(cudaStreamDestroy(st));
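The body of generate_ptr_offset_kernel is not part of this diff. Judging only from the call site above (a count of batch_size, the weight base pointer, a stride of out_feat * in_feat, the gate, and Barray), a sketch consistent with that signature might look as follows; the gate dtype and parameter names are assumptions:

    // Assumed shape of generate_ptr_offset_kernel, inferred from its call site.
    template <typename scalar_t>
    __global__ void generate_ptr_offset_kernel(size_t n, const scalar_t* base,
                                               size_t stride, const long* gate,
                                               const scalar_t** ptrs) {
        size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx < n) {
            // Entry idx points at the weight block of the expert gated for sample idx.
            ptrs[idx] = base + stride * gate[idx];
        }
    }

Filling Barray with a kernel keeps the gate tensor on the device, which is presumably why only aptrs and cptrs are assembled on the host.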
@@ -133,13 +129,12 @@ std::vector<torch::Tensor> moe1_cuda_forward(
         torch::Tensor gate,
         torch::Tensor weight) {
     const auto batch_size = input.size(0);
-    const auto top_k = gate.size(1);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);
     // printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld, topk=%ld\n", batch_size, num_expert, in_feat, out_feat, top_k);
-    auto output = input.new_zeros({batch_size, top_k, out_feat});
+    auto output = input.new_zeros({batch_size, out_feat});
     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe1_forward_cuda", ([&] {
         moe1_cuda_forward_impl<scalar_t>(
@@ -148,7 +143,6 @@ std::vector<torch::Tensor> moe1_cuda_forward(
             weight.data_ptr<scalar_t>(),
             output.data_ptr<scalar_t>(),
             batch_size,
-            top_k,
             in_feat,
             out_feat);
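Net effect on the interface: after this commit moe1_cuda_forward expects gate as a [B] tensor of expert indices and returns a [B x out_feat] output; the K dimension is gone from both the arguments and the result.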