OpenDAS / FastMoE / Commits
"git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "1e5f7a2f9e30f3be8bf1ea215bcd984869023266"
Commit 2b5672e5 authored Dec 17, 2020 by Jiezhong Qiu

update

parent b7c1b308
Showing 3 changed files with 158 additions and 254 deletions
.gitignore                        +2   -1
pytorch/cuda/moe.cpp              +8   -253
pytorch/cuda/moe_cuda_kernel.cu   +148 -0
.gitignore
@@ -3,4 +3,5 @@ data/
 libtorch-shared-with-deps-*
 pytorch/cuda/build
 exp/
-.vscode/
\ No newline at end of file
+.vscode/
+a.out
\ No newline at end of file
pytorch/cuda/moe.cpp
@@ -5,7 +5,8 @@
 #include <iostream>
 #include <vector>
 
 // CUDA runtime
+#include <cuda.h>
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
@@ -14,100 +15,10 @@
 #include <helper_cuda.h>
 
-const int num_stream = 512;
-
-inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
-        cublasOperation_t transa, cublasOperation_t transb,
-        int m, int n, int k,
-        const float* alpha,
-        const float* Aarray[], int lda,
-        const float* Barray[], int ldb,
-        const float* beta,
-        float* Carray[], int ldc,
-        int batchCount) {
-    return cublasSgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
-}
-
-inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
-        cublasOperation_t transa, cublasOperation_t transb,
-        int m, int n, int k,
-        const double* alpha,
-        const double* Aarray[], int lda,
-        const double* Barray[], int ldb,
-        const double* beta,
-        double* Carray[], int ldc,
-        int batchCount) {
-    return cublasDgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
-}
-
-inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
-        cublasOperation_t transa, cublasOperation_t transb,
-        int m, int n, int k,
-        const __half* alpha,
-        const __half* Aarray[], int lda,
-        const __half* Barray[], int ldb,
-        const __half* beta,
-        __half* Carray[], int ldc,
-        int batchCount) {
-    return cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
-}
-
-template <typename scalar_t>
-void moe_cuda_forward_impl(
-        const scalar_t* input,
-        const size_t* gate,
-        const scalar_t* weight,
-        scalar_t* output,
-        size_t batch_size,
-        size_t top_k,
-        size_t in_feat,
-        size_t out_feat) {
-    cublasHandle_t handle;
-    checkCudaErrors(cublasCreate(&handle));
-
-    // setup Aarray, Barray and Carray
-    std::vector<scalar_t*> aptrs, bptrs, cptrs;
-    scalar_t** ptrs;
-    checkCudaErrors(cudaMalloc(&ptrs, batch_size * sizeof(scalar_t*) * top_k * 3));
-    for (size_t i = 0; i < batch_size; ++i) {
-        for (size_t k = 0; k < top_k; ++k) {
-            aptrs.push_back(input + in_feat * i);
-            bptrs.push_back(weight + out_feat * in_feat * gate[i * top_k + k]);
-            cptrs.push_back(output + out_feat * (i * top_k + k));
-        }
-    }
-    checkCudaErrors(cudaMemcpy(ptrs, aptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
-    checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k, bptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
-    checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k * 2, cptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
-
-    scalar_t alpha = 1, beta = 0;
-    checkCudaErrors(cublasXgemmBatched(handle,
-            CUBLAS_OP_N, CUBLAS_OP_T,
-            1, out_feat, in_feat,
-            &alpha,
-            ptrs, 1,
-            ptrs + batch_size * top_k, out_feat,
-            &beta,
-            ptrs + batch_size * top_k * 2, 1,
-            batch_size));
-    cudaStreamSynchronize(st);
-}
-
-void moe_cuda_forward(
+std::vector<torch::Tensor> moe_cuda_forward(
     torch::Tensor input,  // [B x D_model]
     torch::Tensor gate,   // [B x K]
-    torch::Tensor weight, // [N x D_ffn x D_model]
+    torch::Tensor weight  // [N x D_ffn x D_model]
     ) {
     /*
         The bias term should have been merged into weight. Note the following fact that
@@ -120,10 +31,10 @@ void moe_cuda_forward(
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);
 
-    printf("b=%d, expert=%d, in_feat (d_model)=%d, out_feat (d_ffn)=%d, topk=%d\n", batch_size, num_expert, d_model, d_ffn, top_k);
+    printf("b=%d, expert=%d, in_feat (d_model)=%d, out_feat (d_ffn)=%d, topk=%d\n", batch_size, num_expert, in_feat, out_feat, top_k);
 
     auto output = input.new_zeros({batch_size, top_k, out_feat});
 
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "moe_cuda_forward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_forward", ([&] {
         moe_cuda_forward_impl<scalar_t>(
             input.data_ptr<scalar_t>(),
             gate.data_ptr<size_t>(),
@@ -135,164 +46,9 @@ void moe_cuda_forward(
             out_feat
         );
     }));
-    cublasHandle_t handle;
-    checkCudaErrors(cublasCreate(&handle));
-    cudaStream_t stream[num_stream];
-    for (size_t i = 0; i < num_stream; ++i) {
-        checkCudaErrors(cudaStreamCreate(&stream[i]));
-    }
-    cudaEvent_t start, stop;
-    checkCudaErrors(cudaEventCreate(&start));
-    checkCudaErrors(cudaEventCreate(&stop));
-    // Record the start event
-    checkCudaErrors(cudaEventRecord(start, NULL));
-    size_t s;
-    for (size_t i = 0; i < batch_size; ++i) {
-        for (size_t j = 0; j < num_expert; ++j) {
-            s = (i * num_expert + j) % num_stream;
-            // printf("i=%d j=%d goes to stream %d\n", i, j, s);
-            checkCudaErrors(cublasSetStream(handle, stream[s]));
-            if (input.scalar_type() == torch::ScalarType::Float) {
-                float alpha = 1.0;
-                float beta = 0.0;
-                checkCudaErrors(cublasSgemm(handle,
-                        CUBLAS_OP_N,
-                        CUBLAS_OP_N,
-                        1,        // m
-                        d_ffn,    // n
-                        d_model,  // k
-                        &alpha,
-                        input[i].data_ptr<float>(), 1,
-                        weight.index(gate[i][j]).data_ptr<float>(), d_model,
-                        &beta,
-                        output[i][j].data_ptr<float>(), 1));
-            } else {
-                printf("only support float!!!\n");
-            }
-        }
-    }
-    // checkCudaErrors(cudaDeviceSynchronize());
-    // Record the stop event
-    checkCudaErrors(cudaEventRecord(stop, NULL));
-    // Wait for the stop event to complete
-    checkCudaErrors(cudaEventSynchronize(stop));
-    float msecTotal = 0.0f;
-    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));
-    // Compute and print the performance
-    float msecPerMatrixMul = msecTotal / batch_size / num_expert;
-    double flopsPerMatrixMul = 2.0 * (double)d_model * (double)d_ffn;
-    double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
-    printf("Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n", gigaFlops, msecPerMatrixMul, flopsPerMatrixMul);
-    // std::cout << output << std::endl;
-    for (size_t i = 0; i < num_stream; ++i) {
-        checkCudaErrors(cudaStreamDestroy(stream[i]));
-    }
-    checkCudaErrors(cublasDestroy(handle));
-}
-// std::vector<torch::Tensor>
-void moe_cuda_forward_v1(
-        torch::Tensor input,  // [B x D_model]
-        torch::Tensor gate,   // [B x N]
-        torch::Tensor weight, // [N x D_model x D_ffn]
-        torch::Tensor bias    // [N x D_ffn]
-        ) {
-    const auto batch_size = input.size(0);
-    const auto num_expert = gate.size(1);
-    const auto d_model = weight.size(1);
-    const auto d_ffn = weight.size(2);
-    printf("b=%d, expert=%d, d_model=%d, d_ffn=%d\n", batch_size, num_expert, d_model, d_ffn);
-    auto output = input.new_zeros({batch_size, num_expert, d_ffn});
-    cublasHandle_t handle;
-    checkCudaErrors(cublasCreate(&handle));
-    cudaStream_t stream[num_stream];
-    for (size_t i = 0; i < num_stream; ++i) {
-        checkCudaErrors(cudaStreamCreate(&stream[i]));
-    }
-    cudaEvent_t start, stop;
-    checkCudaErrors(cudaEventCreate(&start));
-    checkCudaErrors(cudaEventCreate(&stop));
-    // Record the start event
-    checkCudaErrors(cudaEventRecord(start, NULL));
-    size_t s;
-    for (size_t i = 0; i < batch_size; ++i) {
-        for (size_t j = 0; j < num_expert; ++j) {
-            s = (i * num_expert + j) % num_stream;
-            // printf("i=%d j=%d goes to stream %d\n", i, j, s);
-            checkCudaErrors(cublasSetStream(handle, stream[s]));
-            if (input.scalar_type() == torch::ScalarType::Float) {
-                float alpha = 1.0;
-                float beta = 0.0;
-                checkCudaErrors(cublasSgemm(handle,
-                        CUBLAS_OP_N,
-                        CUBLAS_OP_N,
-                        1,        // m
-                        d_ffn,    // n
-                        d_model,  // k
-                        &alpha,
-                        input[i].data_ptr<float>(), 1,
-                        weight.index(gate[i][j]).data_ptr<float>(), d_model,
-                        &beta,
-                        output[i][j].data_ptr<float>(), 1));
-            } else {
-                printf("only support float!!!\n");
-            }
-        }
-    }
-    // checkCudaErrors(cudaDeviceSynchronize());
-    // Record the stop event
-    checkCudaErrors(cudaEventRecord(stop, NULL));
-    // Wait for the stop event to complete
-    checkCudaErrors(cudaEventSynchronize(stop));
-    float msecTotal = 0.0f;
-    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));
-    // Compute and print the performance
-    float msecPerMatrixMul = msecTotal / batch_size / num_expert;
-    double flopsPerMatrixMul = 2.0 * (double)d_model * (double)d_ffn;
-    double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
-    printf("Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n", gigaFlops, msecPerMatrixMul, flopsPerMatrixMul);
-    // std::cout << output << std::endl;
-    for (size_t i = 0; i < num_stream; ++i) {
-        checkCudaErrors(cudaStreamDestroy(stream[i]));
-    }
-    checkCudaErrors(cublasDestroy(handle));
-}
+    return {output, };
+}
 
 // C++ interface
 // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
 
@@ -306,7 +62,6 @@ int main() {
     torch::Tensor input = torch::randn({2048, 512}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
     torch::Tensor gate = torch::zeros({2048, 2}, torch::dtype(torch::kInt64));
     torch::Tensor weight = torch::randn({2, 512, 2048}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
-    torch::Tensor bias = torch::randn({2, 2048}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
     checkCudaErrors(cudaSetDevice(device));
-    moe_cuda_forward_v1(input, gate, weight, bias);
+    moe_cuda_forward(input, gate, weight);
 }
\ No newline at end of file
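The truncated comment in the hunk above ("The bias term should have been merged into weight. Note the following fact that ...") appears to refer to the usual trick of absorbing a bias vector into the weight matrix, so the expert forward pass reduces to a pure (batched) GEMM: append a constant 1 to the input and the bias as an extra weight column. The following minimal sketch of that identity is illustrative only; it is not part of the commit, and all names and values in it are made up.

// Illustration only: folding a bias into the weight matrix.
// For y = W x + b with W of shape [out_feat x in_feat], define
// W' = [W | b] of shape [out_feat x (in_feat + 1)] and x' = [x; 1].
// Then y = W' x', so a bias-free GEMM such as cublasXgemmBatched suffices.
#include <cstdio>
#include <vector>

int main() {
    const int in_feat = 3, out_feat = 2;
    std::vector<float> x = {1.0f, 2.0f, 3.0f};
    std::vector<float> W = {0.1f, 0.2f, 0.3f,    // row 0
                            0.4f, 0.5f, 0.6f};   // row 1
    std::vector<float> b = {10.0f, 20.0f};

    // Augmented input x' = [x; 1] and weight W' = [W | b].
    std::vector<float> xa = {1.0f, 2.0f, 3.0f, 1.0f};
    std::vector<float> Wa = {0.1f, 0.2f, 0.3f, 10.0f,
                             0.4f, 0.5f, 0.6f, 20.0f};

    for (int o = 0; o < out_feat; ++o) {
        float y = b[o], ya = 0.0f;
        for (int i = 0; i < in_feat; ++i)     y  += W[o * in_feat + i] * x[i];
        for (int i = 0; i < in_feat + 1; ++i) ya += Wa[o * (in_feat + 1) + i] * xa[i];
        printf("out %d: W*x+b = %.2f, W'*x' = %.2f\n", o, y, ya);  // the two agree
    }
    return 0;
}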
pytorch/cuda/moe_cuda_kernel.cu (new file, mode 100644)
#include <cstdio>
#include <iostream>
#include <vector>

// CUDA runtime
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>

// CUDA and CUBLAS functions
//#include <helper_functions.h>
#include <helper_cuda.h>

typedef float data_t;

size_t batch_size = 4096;
size_t top_k = 2;
size_t num_expert = 128;
size_t in_feat = 512;
size_t out_feat = 2048;

#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const size_t* offset, const scalar_t** ptrs) {
    size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx < n) {
        ptrs[idx] = base + stride * offset[idx];
    }
}

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const float* alpha,
        const float* Aarray[], int lda,
        const float* Barray[], int ldb,
        const float* beta,
        float* Carray[], int ldc,
        int batchCount) {
    return cublasSgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const double* alpha,
        const double* Aarray[], int lda,
        const double* Barray[], int ldb,
        const double* beta,
        double* Carray[], int ldc,
        int batchCount) {
    return cublasDgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const __half* alpha,
        const __half* Aarray[], int lda,
        const __half* Barray[], int ldb,
        const __half* beta,
        __half* Carray[], int ldc,
        int batchCount) {
    return cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}

template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const size_t* gate,
        const scalar_t* weight,
        scalar_t* output,
        size_t batch_size,
        size_t top_k,
        size_t in_feat,
        size_t out_feat) {
    cublasHandle_t handle;
    cudaStream_t st;
    cudaStreamCreate(&st);
    checkCudaErrors(cublasCreate(&handle));
    checkCudaErrors(cublasSetStream(handle, st));

    // setup Aarray, Barray and Carray
    std::vector<const scalar_t*> aptrs;
    std::vector<scalar_t*> cptrs;
    const scalar_t **Aarray;
    const scalar_t **Barray;
    scalar_t **Carray;
    checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*) * top_k));
    checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*) * top_k));
    checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*) * top_k));
    for (size_t i = 0; i < batch_size; ++i) {
        for (size_t k = 0; k < top_k; ++k) {
            aptrs.push_back(input + in_feat * i);
            // bptrs.push_back(weight + out_feat * in_feat * gate[i * top_k + k]);
            cptrs.push_back(output + out_feat * (i * top_k + k));
        }
    }
    checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const scalar_t*) * top_k, cudaMemcpyHostToDevice));
    // checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k, bptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(Carray, cptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));

    dim3 griddim(CEIL(batch_size * top_k, 256));
    dim3 blockdim(256);
    generate_ptr_offset_kernel<<<griddim, blockdim, 0, st>>>(batch_size * top_k, weight, out_feat * in_feat, gate, Barray);

    scalar_t alpha = 1, beta = 0;
    checkCudaErrors(cublasXgemmBatched(handle,
            CUBLAS_OP_N, CUBLAS_OP_T,
            1, out_feat, in_feat,
            &alpha,
            Aarray, 1,
            Barray, out_feat,
            &beta,
            Carray, 1,
            batch_size));

    checkCudaErrors(cudaStreamSynchronize(st));
    checkCudaErrors(cudaStreamDestroy(st));
    checkCudaErrors(cublasDestroy(handle));
}

int main() {
    const data_t *input, *weight;
    data_t *output;
    const size_t *gate;
    checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(const data_t)));
    checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(const data_t)));
    checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
    moe_cuda_forward_impl<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
}
\ No newline at end of file
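For readers new to the pointer-array form of batched GEMM used above: cublasSgemmBatched and its siblings take device-resident arrays of device pointers, and each batch entry performs one independent GEMM; that is why moe_cuda_forward_impl cudaMallocs Aarray/Barray/Carray, copies the host-built input/output pointer lists into them, and fills Barray on the device with generate_ptr_offset_kernel so each entry points at the weight of the expert chosen by the gate. The sketch below is a plain CPU reference for the per-entry product that those pointer arrays set up (entry t = i * top_k + k: the 1 x in_feat row of sample i times the transposed out_feat x in_feat weight of expert gate[t]). It is illustrative only, not part of the commit; all names and sizes in it are made up.

// Illustration only: CPU reference for the per-entry GEMM set up by the
// Aarray/Barray/Carray pointer lists in moe_cuda_forward_impl.
#include <cstddef>
#include <cstdio>
#include <vector>

void moe_forward_reference(const float* input,   // [batch_size x in_feat]
                           const size_t* gate,   // [batch_size x top_k]
                           const float* weight,  // [num_expert x out_feat x in_feat]
                           float* output,        // [batch_size x top_k x out_feat]
                           size_t batch_size, size_t top_k,
                           size_t in_feat, size_t out_feat) {
    for (size_t i = 0; i < batch_size; ++i) {
        for (size_t k = 0; k < top_k; ++k) {
            const size_t t = i * top_k + k;
            const float* a = input + in_feat * i;                    // Aarray[t]
            const float* b = weight + out_feat * in_feat * gate[t];  // Barray[t]
            float* c = output + out_feat * t;                        // Carray[t]
            for (size_t o = 0; o < out_feat; ++o) {
                float acc = 0.0f;
                // Row of A times row o of B, matching CUBLAS_OP_T on B.
                for (size_t d = 0; d < in_feat; ++d)
                    acc += a[d] * b[o * in_feat + d];
                c[o] = acc;
            }
        }
    }
}

int main() {
    const size_t batch_size = 2, top_k = 2, num_expert = 3, in_feat = 4, out_feat = 5;
    std::vector<float> input(batch_size * in_feat, 1.0f);
    std::vector<float> weight(num_expert * out_feat * in_feat, 0.5f);
    std::vector<float> output(batch_size * top_k * out_feat, 0.0f);
    std::vector<size_t> gate = {0, 2, 1, 1};  // one expert id per (sample, k) slot
    moe_forward_reference(input.data(), gate.data(), weight.data(), output.data(),
                          batch_size, top_k, in_feat, out_feat);
    printf("output[0] = %.2f\n", output[0]);  // 4 * 1.0 * 0.5 = 2.00
    return 0;
}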