OpenDAS / FastMoE · Commits

Commit 5ad13a9e, authored Dec 17, 2020 by Jiezhong Qiu
update
parent b6f23ba0

1 changed file, 194 additions and 3 deletions:
pytorch/cuda/moe.cpp (+194 −3)
@@ -16,8 +16,199 @@
const int num_stream = 512;

// std::vector<torch::Tensor>
inline cublasStatus_t cublasXgemmBatched(
        cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const float* alpha,
        const float* Aarray[], int lda,
        const float* Barray[], int ldb,
        const float* beta,
        float* Carray[], int ldc,
        int batchCount) {
    // FP32 overload dispatches to cublasSgemmBatched.
    return cublasSgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}
inline cublasStatus_t cublasXgemmBatched(
        cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const double* alpha,
        const double* Aarray[], int lda,
        const double* Barray[], int ldb,
        const double* beta,
        double* Carray[], int ldc,
        int batchCount) {
    // FP64 overload dispatches to cublasDgemmBatched.
    return cublasDgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}
inline cublasStatus_t cublasXgemmBatched(
        cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const __half* alpha,
        const __half* Aarray[], int lda,
        const __half* Barray[], int ldb,
        const __half* beta,
        __half* Carray[], int ldc,
        int batchCount) {
    // FP16 overload dispatches to cublasHgemmBatched.
    return cublasHgemmBatched(handle, transa, transb, m, n, k,
            alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}
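// A minimal sketch (not part of this commit) of how the overloads above are
// meant to be used: a templated caller picks the float, double, or __half
// batched GEMM purely by overload resolution, with no runtime dtype branch.
// `batched_gemm_example` is a hypothetical name.
template <typename scalar_t>
void batched_gemm_example(cublasHandle_t handle,
        const scalar_t* Aarray[], const scalar_t* Barray[], scalar_t* Carray[],
        int m, int n, int k, int batchCount) {
    scalar_t alpha = scalar_t(1.0f), beta = scalar_t(0.0f);
    // Computes C_i = A_i * B_i for each of the batchCount problems.
    checkCudaErrors(cublasXgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k, &alpha, Aarray, m, Barray, k, &beta, Carray, m,
            batchCount));
}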
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const size_t* gate,
        const scalar_t* weight,
        scalar_t* output,
        size_t batch_size,
        size_t top_k,
        size_t in_feat,
        size_t out_feat) {
    cublasHandle_t handle;
    checkCudaErrors(cublasCreate(&handle));

    // Set up Aarray, Barray and Carray: one 1 x out_feat GEMM per
    // (sample, top-k choice) pair, all packed into one device buffer.
    std::vector<const scalar_t*> aptrs, bptrs;
    std::vector<scalar_t*> cptrs;
    scalar_t** ptrs;
    checkCudaErrors(cudaMalloc(&ptrs,
            batch_size * sizeof(scalar_t*) * top_k * 3));
    for (size_t i = 0; i < batch_size; ++i) {
        for (size_t k = 0; k < top_k; ++k) {
            aptrs.push_back(input + in_feat * i);
            bptrs.push_back(weight + out_feat * in_feat * gate[i * top_k + k]);
            cptrs.push_back(output + out_feat * (i * top_k + k));
        }
    }
    checkCudaErrors(cudaMemcpy(ptrs, aptrs.data(),
            batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k, bptrs.data(),
            batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k * 2, cptrs.data(),
            batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));

    scalar_t alpha = 1, beta = 0;
    // batchCount must cover every (sample, k) pair, i.e. batch_size * top_k.
    checkCudaErrors(cublasXgemmBatched(handle,
            CUBLAS_OP_N,
            CUBLAS_OP_T,
            1, out_feat, in_feat,
            &alpha,
            (const scalar_t**)ptrs, 1,
            (const scalar_t**)(ptrs + batch_size * top_k), out_feat,
            &beta,
            ptrs + batch_size * top_k * 2, 1,
            batch_size * top_k));
    // The original synchronized on an undefined stream `st`; synchronize the
    // device instead before releasing the pointer buffer.
    checkCudaErrors(cudaDeviceSynchronize());
    checkCudaErrors(cudaFree(ptrs));
    checkCudaErrors(cublasDestroy(handle));
}
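// A minimal host-side reference for moe_cuda_forward_impl (hypothetical
// helper, not part of this commit), useful for validating the batched GEMM
// above. It assumes each expert's weight is stored row-major as
// [out_feat x in_feat], matching the [N x D_ffn x D_model] tensor shape.
template <typename scalar_t>
void moe_forward_reference(const scalar_t* input, const size_t* gate,
        const scalar_t* weight, scalar_t* output,
        size_t batch_size, size_t top_k, size_t in_feat, size_t out_feat) {
    for (size_t i = 0; i < batch_size; ++i) {
        for (size_t k = 0; k < top_k; ++k) {
            // Same addressing as the device code above.
            const scalar_t* w = weight + out_feat * in_feat * gate[i * top_k + k];
            scalar_t* out = output + out_feat * (i * top_k + k);
            for (size_t o = 0; o < out_feat; ++o) {
                scalar_t acc = 0;
                for (size_t j = 0; j < in_feat; ++j) {
                    acc += w[o * in_feat + j] * input[i * in_feat + j];
                }
                out[o] = acc;
            }
        }
    }
}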
void moe_cuda_forward(
        torch::Tensor input,  // [B x D_model]
        torch::Tensor gate,   // [B x K]
        torch::Tensor weight  // [N x D_ffn x D_model]
        ) {
    /*
        The bias term should have been merged into weight, using the fact that

            Wx + b = [W b] [x]
                           [1]
    */
    const auto batch_size = input.size(0);
    const auto top_k = gate.size(1);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);
    printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld, topk=%ld\n",
            batch_size, num_expert, in_feat, out_feat, top_k);
    auto output = input.new_zeros({batch_size, top_k, out_feat});
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_forward", ([&] {
        moe_cuda_forward_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            // gate holds int64 expert indices; reinterpret as size_t (LP64).
            reinterpret_cast<const size_t*>(gate.data_ptr<int64_t>()),
            weight.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            top_k,
            in_feat,
            out_feat);
    }));
    cublasHandle_t handle;
    checkCudaErrors(cublasCreate(&handle));
    cudaStream_t stream[num_stream];
    for (size_t i = 0; i < num_stream; ++i) {
        checkCudaErrors(cudaStreamCreate(&stream[i]));
    }
    cudaEvent_t start, stop;
    checkCudaErrors(cudaEventCreate(&start));
    checkCudaErrors(cudaEventCreate(&stop));

    // Record the start event
    checkCudaErrors(cudaEventRecord(start, NULL));

    // Round-robin the per-(sample, choice) GEMMs over num_stream streams.
    // Note: the inner loop runs j over num_expert but indexes gate[i][j] and
    // output[i][j], which assumes num_expert == top_k (true in main() below).
    size_t s;
    for (size_t i = 0; i < batch_size; ++i) {
        for (size_t j = 0; j < num_expert; ++j) {
            s = (i * num_expert + j) % num_stream;
            // printf("i=%d j=%d goes to stream %d\n", i, j, s);
            checkCudaErrors(cublasSetStream(handle, stream[s]));
            if (input.scalar_type() == torch::ScalarType::Float) {
                float alpha = 1.0;
                float beta = 0.0;
                checkCudaErrors(cublasSgemm(handle,
                        CUBLAS_OP_N,
                        CUBLAS_OP_N,
                        1,              // m
                        out_feat,       // n (d_ffn)
                        in_feat,        // k (d_model)
                        &alpha,
                        input[i].data_ptr<float>(),
                        1,
                        weight.index({gate[i][j]}).data_ptr<float>(),
                        in_feat,
                        &beta,
                        output[i][j].data_ptr<float>(),
                        1));
            } else {
                printf("only support float!!!\n");
            }
        }
    }

    // checkCudaErrors(cudaDeviceSynchronize());
    // Record the stop event
    checkCudaErrors(cudaEventRecord(stop, NULL));
    // Wait for the stop event to complete
    checkCudaErrors(cudaEventSynchronize(stop));
    float msecTotal = 0.0f;
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

    // Compute and print the performance
    float msecPerMatrixMul = msecTotal / batch_size / num_expert;
    double flopsPerMatrixMul = 2.0 * (double)in_feat * (double)out_feat;
    double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
    printf("Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n",
            gigaFlops, msecPerMatrixMul, flopsPerMatrixMul);
    // std::cout << output << std::endl;

    for (size_t i = 0; i < num_stream; ++i) {
        checkCudaErrors(cudaStreamDestroy(stream[i]));
    }
    checkCudaErrors(cudaEventDestroy(start));
    checkCudaErrors(cudaEventDestroy(stop));
    checkCudaErrors(cublasDestroy(handle));
}
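// A minimal sketch of the bias folding that the comment in moe_cuda_forward
// assumes, based on Wx + b = [W b] [x; 1]. Both helpers are hypothetical
// names, not part of this file.
torch::Tensor merge_bias_into_weight(torch::Tensor weight,  // [N x D_ffn x D_model]
                                     torch::Tensor bias) {  // [N x D_ffn]
    // Append the bias as one extra input column per expert:
    // result is [N x D_ffn x (D_model + 1)].
    return torch::cat({weight, bias.unsqueeze(2)}, /*dim=*/2);
}

torch::Tensor append_ones_column(torch::Tensor input) {     // [B x D_model]
    // Extend every sample with a constant 1 so the folded bias is applied:
    // result is [B x (D_model + 1)].
    auto ones = torch::ones({input.size(0), 1}, input.options());
    return torch::cat({input, ones}, /*dim=*/1);
}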
// std::vector<torch::Tensor>
void moe_cuda_forward_v1(
        torch::Tensor input,   // [B x D_model]
        torch::Tensor gate,    // [B x N]
        torch::Tensor weight,  // [N x D_model x D_ffn]
...
...
@@ -113,9 +304,9 @@ void moe_cuda_forward(
int main() {
    int device = 2;
    torch::Tensor input = torch::randn({2048, 512},
            torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
    torch::Tensor gate = torch::zeros({2048, 2},
            torch::dtype(torch::kInt64).device(torch::kCUDA, device));
    torch::Tensor weight = torch::randn({2, 512, 2048},
            torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
    torch::Tensor bias = torch::randn({2, 2048},
            torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
    checkCudaErrors(cudaSetDevice(device));
    // The bias is assumed to be merged into weight for moe_cuda_forward (see
    // the comment in that function), so it is not passed here.
    moe_cuda_forward(input, gate, weight);
    moe_cuda_forward_v1(input, gate, weight, bias);
}
\ No newline at end of file