OpenDAS / FastMoE · Commits

Commit 49732231, authored Jan 09, 2021 by Rick Ho
Parent: 143e21cc

    split operators and make forward run

Showing 5 changed files with 226 additions and 92 deletions:

    pytorch/cuda/cuda_stream_manager.cpp   +21   -0
    pytorch/cuda/cuda_stream_manager.h      +8  -16
    pytorch/cuda/moe.cpp                   +45   -9
    pytorch/cuda/moe.py                     +8   -4
    pytorch/cuda/moe_cuda_kernel.cu       +144  -63
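Taken together, the changes replace the single fused `moe_cuda.forward(inp, gate, weight)` call with four bound operators: `expert_count`, `local_scatter`, `forward`, and `local_gather`. A minimal sketch of the resulting data flow, using hypothetical tensor names and shapes (illustrative only, not from the commit):

```python
import torch
import moe_cuda  # the extension built from pytorch/cuda/moe.cpp

batch_size, num_expert, in_feat, out_feat = 8, 4, 16, 32
inp = torch.randn(batch_size, in_feat, device='cuda')
gate = torch.randint(num_expert, (batch_size,), dtype=torch.int32, device='cuda')
weight = torch.randn(num_expert, out_feat, in_feat, device='cuda')

# 1. Count rows per expert; pos[i] is row i's slot in the reordered buffer.
expert_count, pos = moe_cuda.expert_count(weight, gate)
# 2. Reorder rows so each expert's inputs are contiguous.
input_buf, = moe_cuda.local_scatter(inp, pos)
# 3. One GEMM per expert over its contiguous slice.
output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
# 4. Undo the reordering.
output, = moe_cuda.local_gather(output_buf, pos)
```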
pytorch/cuda/cuda_stream_manager.cpp

```diff
@@ -13,3 +13,24 @@ void CudaStreamManager::sync(int i) {
         cudaStreamSynchronize(streams[i]);
     }
 }
+
+void CudaStreamManager::setup(const size_t num_expert, const int device) {
+#ifdef MOE_DEBUG
+    printf("setup at device %d\n", device);
+#endif
+    this->num_expert = num_expert;
+    if (device == -1) {
+        checkCudaErrors(cudaGetDevice(&this->device));
+    } else {
+        this->device = device;
+    }
+    checkCudaErrors(cudaSetDevice(this->device));
+    streams = new cudaStream_t[num_expert];
+    handles = new cublasHandle_t[num_expert];
+    for (size_t i = 0; i < num_expert; ++i) {
+        checkCudaErrors(cudaStreamCreate(streams + i));
+        checkCudaErrors(cublasCreate(handles + i));
+        cublasSetStream(handles[i], streams[i]);
+    }
+}
```
pytorch/cuda/cuda_stream_manager.h

```diff
@@ -16,7 +16,7 @@ public:
     cudaStream_t* streams;
 
 public:
-    CudaStreamManager() : num_expert(0), device(0), streams(NULL) {
+    CudaStreamManager() : num_expert(0), streams(NULL) {
        int current_device;
        checkCudaErrors(cudaGetDevice(&current_device));
 #ifdef MOE_DEBUG
@@ -24,21 +24,7 @@ public:
 #endif
     }
-    void setup(const size_t num_expert, const int device) {
-#ifdef MOE_DEBUG
-        printf("setup at device %d\n", device);
-#endif
-        this->num_expert = num_expert;
-        this->device = device;
-        checkCudaErrors(cudaSetDevice(device));
-        streams = new cudaStream_t[num_expert];
-        handles = new cublasHandle_t[num_expert];
-        for (size_t i = 0; i < num_expert; ++i) {
-            checkCudaErrors(cudaStreamCreate(streams + i));
-            checkCudaErrors(cublasCreate(handles + i));
-            cublasSetStream(handles[i], streams[i]);
-        }
-    }
+    void setup(const size_t num_expert, const int device=-1);
     ~CudaStreamManager() {
 #ifdef MOE_DEBUG
@@ -54,6 +40,12 @@ public:
     void sync(int=-1);
 };
 
+#define ENSURE_SMGR(__smgr__, __num_expert__) { \
+    if (__smgr__.num_expert == 0) { \
+        __smgr__.setup(__num_expert__); \
+    } \
+}
+
 // CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
 #endif  // CUDA_STREAM_MANAGER
```
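The header now only declares `setup` (its definition moved to the .cpp above, with the device defaulting to the current one) and adds `ENSURE_SMGR`, which lazily initializes the manager the first time an operator runs. A rough Python analogue of that idea, one CUDA stream per expert created on first use (an illustrative sketch, not code from the repository):

```python
import torch

class StreamManager:
    """Rough analogue of CudaStreamManager: per-expert streams, lazy setup."""

    def __init__(self):
        self.streams = None  # not set up yet, like num_expert == 0 in the macro

    def setup(self, num_expert, device=None):
        # device=None mirrors the new `device=-1` default: use the current device.
        device = torch.cuda.current_device() if device is None else device
        self.streams = [torch.cuda.Stream(device=device) for _ in range(num_expert)]

    def ensure(self, num_expert):
        # ENSURE_SMGR: set up on the first call, reuse afterwards.
        if self.streams is None:
            self.setup(num_expert)

    def sync(self, i=None):
        # sync(int=-1): synchronize one stream, or all of them.
        for s in (self.streams if i is None else [self.streams[i]]):
            s.synchronize()
```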
pytorch/cuda/moe.cpp

```diff
@@ -4,10 +4,22 @@
 #include <iostream>
 #include <vector>
 
-std::vector<torch::Tensor> moe_cuda_forward(
-        torch::Tensor input,
-        torch::Tensor gate,
-        torch::Tensor weight);
+std::vector<torch::Tensor> moe_cuda_expert_count(
+        torch::Tensor weight,  // TODO: pass num-experts in another way?
+        torch::Tensor gate);
+
+std::vector<torch::Tensor> moe_cuda_local_scatter(
+        torch::Tensor input,
+        torch::Tensor pos);
+
+std::vector<torch::Tensor> moe_cuda_local_gather(
+        torch::Tensor output_buf,
+        torch::Tensor pos);
+
+std::vector<torch::Tensor> moe_cuda_forward(
+        torch::Tensor input_buf,
+        torch::Tensor weight,
+        torch::Tensor expert_count);
 
 std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output,
@@ -22,20 +34,41 @@ std::vector<torch::Tensor> moe_cuda_backward(
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
+std::vector<torch::Tensor> moe_expert_count(
+        torch::Tensor weight,
+        torch::Tensor gate) {
+    CHECK_INPUT(gate);
+    return moe_cuda_expert_count(weight, gate);
+}
+
+std::vector<torch::Tensor> moe_local_scatter(
+        torch::Tensor input,
+        torch::Tensor pos) {
+    CHECK_INPUT(input);
+    return moe_cuda_local_scatter(input, pos);
+}
+
+std::vector<torch::Tensor> moe_local_gather(
+        torch::Tensor output_buf,
+        torch::Tensor pos) {
+    CHECK_INPUT(output_buf);
+    return moe_cuda_local_gather(output_buf, pos);
+}
+
 std::vector<torch::Tensor> moe_forward(
-        torch::Tensor input,  // [batch_size x in_feat]
-        torch::Tensor gate,   // [batch_size]
-        torch::Tensor weight  // [num_expert x hidden_feat x in_feat]
+        torch::Tensor input_buf,    // [batch_size x in_feat]
+        torch::Tensor weight,       // [num_expert x hidden_feat x in_feat]
+        torch::Tensor expert_count  // [batch_size]
         ) {
-    CHECK_INPUT(input);
-    CHECK_INPUT(gate);
+    CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
     /*
     The bias term should have been merged into weight. Note the following fact that
     Wx+b = [W b] [x]
                  [1]
     */
-    return moe_cuda_forward(input, gate, weight);
+    return moe_cuda_forward(input_buf, weight, expert_count);
 }
 
 std::vector<torch::Tensor> moe_backward(
@@ -69,6 +102,9 @@ int main() {
 */
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
+    m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
+    m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
     m.def("forward", &moe_forward, "MoE forward (CUDA)");
     m.def("backward", &moe_backward, "MoE backward (CUDA)");
 }
```
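The comment retained in `moe_forward` justifies dropping an explicit bias term: a bias can be folded into the weight by appending it as an extra column and appending a constant 1 to the input. A quick torch check of that identity (names and shapes here are illustrative):

```python
import torch

# Verify W x + b == [W b] [x; 1], the identity the code comment relies on.
in_feat, out_feat = 4, 3
W = torch.randn(out_feat, in_feat)
b = torch.randn(out_feat)
x = torch.randn(in_feat)

merged_weight = torch.cat([W, b.unsqueeze(1)], dim=1)  # [out_feat x (in_feat+1)]
x_aug = torch.cat([x, torch.ones(1)])                  # input with constant 1 appended

assert torch.allclose(W @ x + b, merged_weight @ x_aug)
```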
pytorch/cuda/moe.py

```diff
@@ -11,8 +11,12 @@ class MOEFunction(Function):
     def forward(ctx, inp, gate, weight):
         # out_feat, in_feat = weight.size()[1:]
         # weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
-        output = moe_cuda.forward(inp, gate, weight)
-        variables = [inp, gate, weight]
+        expert_count, pos = moe_cuda.expert_count(weight, gate)
+        input_buf, = moe_cuda.local_scatter(inp, pos)
+        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
+        output = moe_cuda.local_gather(output_buf, pos)
+        variables = [inp, gate, weight, expert_count, pos]
         ctx.save_for_backward(*variables)
         return output[0]
@@ -138,5 +142,5 @@ def test_dp():
 
 if __name__ == '__main__':
-    test()
-    # test_dp()
+    # test()
+    test_dp()
\ No newline at end of file
```
pytorch/cuda/moe_cuda_kernel.cu

```diff
@@ -36,7 +36,7 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
 template <typename scalar_t>
 __global__
-void batch_scatter_kernel(int wid, int* pos,
+void batch_scatter_kernel(size_t wid, const int* pos,
         const scalar_t* inbuf, scalar_t* oubuf) {
     inbuf += wid * blockIdx.x;
     oubuf += wid * pos[blockIdx.x];
```
```diff
@@ -45,39 +45,14 @@ void batch_scatter_kernel(int wid, int* pos,
     }
 }
 
-template <typename scalar_t>
-__global__
-void batch_gather_kernel(int wid, int* pos,
-        const scalar_t* inbuf, scalar_t* oubuf) {
-    inbuf += wid * pos[blockIdx.x];
-    oubuf += wid * blockIdx.x;
-    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
-        oubuf[i] = inbuf[i];
-    }
-}
-
 template <typename scalar_t>
-void moe_cuda_forward_impl(
-        const scalar_t* input,
+void moe_cuda_expert_count_impl(
         const int* d_gate,
-        const scalar_t* weight,
-        scalar_t* output,
-        const size_t batch_size,
-        const size_t in_feat,
-        const size_t out_feat,
-        const size_t num_expert,
-        cublasOperation_t transb) {
-    scalar_t *input_buf, *output_buf;
-    checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size * in_feat));
-    checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size * out_feat));
+        int* expert_count,
+        int* d_pos,
+        const size_t num_expert,
+        const size_t batch_size) {
     int* gate = new int[batch_size];
-    int* expert_count = new int[num_expert], *expert_ptr = new int[num_expert];
+    int* expert_ptr = new int[num_expert];
     memset(expert_count, 0, sizeof(int) * num_expert);
     checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
```
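The new `moe_cuda_expert_count_impl` does this bookkeeping on the host: count rows per expert, build exclusive prefix offsets (`expert_ptr`), then assign each row a slot via `pos[i] = expert_ptr[gate[i]]++` (visible in the next hunk). An illustrative re-implementation in plain torch, not code from the commit:

```python
import torch

def expert_count_ref(gate: torch.Tensor, num_expert: int):
    # expert_count[e] = number of rows routed to expert e
    expert_count = torch.bincount(gate.long(), minlength=num_expert).int()
    # expert_ptr[e] = first buffer slot owned by expert e (exclusive prefix sum)
    expert_ptr = torch.cumsum(expert_count, 0) - expert_count
    ptr = expert_ptr.clone()
    pos = torch.empty_like(gate)
    for i, e in enumerate(gate.tolist()):  # pos[i] = expert_ptr[gate[i]]++
        pos[i] = ptr[e]
        ptr[e] += 1
    return expert_count, pos
```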
```diff
@@ -92,19 +67,65 @@ void moe_cuda_forward_impl(
     }
 
     int* pos = new int[batch_size];
-    int* d_pos;
-    checkCudaErrors(cudaMalloc(&d_pos, sizeof(int) * batch_size));
     for (int i = 0; i < batch_size; ++i) {
         pos[i] = expert_ptr[gate[i]]++;
     }
     checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
             cudaMemcpyHostToDevice));
+    delete [] gate;
+    delete [] expert_ptr;
+    ENSURE_SMGR(smgr, num_expert);
+}
+
+template <typename scalar_t>
+void moe_cuda_local_scatter_impl(
+        const scalar_t* input,
+        const int* d_pos,
+        scalar_t* input_buf,
+        const size_t batch_size,
+        const size_t in_feat) {
     batch_scatter_kernel<scalar_t>
         <<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
             input_buf);
     smgr.sync(0);
+}
+
+template <typename scalar_t>
+__global__
+void batch_gather_kernel(size_t wid, const int* pos,
+        const scalar_t* inbuf, scalar_t* oubuf) {
+    inbuf += wid * pos[blockIdx.x];
+    oubuf += wid * blockIdx.x;
+    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
+        oubuf[i] = inbuf[i];
+    }
+}
+
+template <typename scalar_t>
+void moe_cuda_local_gather_impl(
+        const scalar_t* output_buf,
+        const int* d_pos,
+        scalar_t* output,
+        const size_t batch_size,
+        const size_t out_feat) {
+    batch_gather_kernel<scalar_t>
+        <<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
+            output);
+    smgr.sync(0);
+}
+
+template <typename scalar_t>
+void moe_cuda_forward_impl(
+        const scalar_t* input_buf,
+        const scalar_t* weight,
+        const int* expert_count,
+        scalar_t* output_buf,
+        const size_t in_feat,
+        const size_t out_feat,
+        const size_t num_expert,
+        cublasOperation_t transb) {
     scalar_t alpha = 1, beta = 0;
```
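Both kernels are plain row permutations: `batch_scatter_kernel` copies input row `blockIdx.x` to buffer row `pos[blockIdx.x]` (one thread block per row, threads striding across the `wid` columns), and `batch_gather_kernel` is its inverse. In torch indexing terms, roughly (an illustrative sketch, assuming `pos` is a permutation as built above):

```python
import torch

def local_scatter_ref(inp, pos):
    pos = pos.long()
    input_buf = torch.empty_like(inp)
    input_buf[pos] = inp       # oubuf row pos[blockIdx.x] <- inbuf row blockIdx.x
    return input_buf

def local_gather_ref(output_buf, pos):
    pos = pos.long()
    return output_buf[pos]     # oubuf row blockIdx.x <- inbuf row pos[blockIdx.x]
```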
```diff
@@ -126,17 +147,7 @@ void moe_cuda_forward_impl(
         ptr += expert_count[i];
     }
-    batch_gather_kernel<scalar_t>
-        <<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
-            output);
-    smgr.sync(0);
-    cudaFree(input_buf);
-    cudaFree(output_buf);
-    cudaFree(d_pos);
-    delete [] pos;
-    delete [] gate;
+    smgr.sync();
 }
 
 template <typename scalar_t>
```
```diff
@@ -176,37 +187,107 @@ void moe_cuda_grad_weight(
     delete [] gate_host;
 }
 
+std::vector<torch::Tensor> moe_cuda_expert_count(
+        torch::Tensor weight,
+        torch::Tensor gate) {
+    const auto batch_size = gate.size(0);
+    const auto num_expert = weight.size(0);
+    auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
+    auto expert_count = torch::empty(num_expert, ec_options);
+    auto pos_options = torch::TensorOptions()
+        .device(gate.device())
+        .dtype(torch::kInt32);
+    auto pos = torch::empty(batch_size, pos_options);
+    moe_cuda_expert_count_impl(
+            gate.data_ptr<int>(),
+            expert_count.data_ptr<int>(),
+            pos.data_ptr<int>(),
+            num_expert,
+            batch_size);
+    return {expert_count, pos};
+}
+
+std::vector<torch::Tensor> moe_cuda_local_scatter(
+        torch::Tensor input,
+        torch::Tensor pos) {
+    const auto batch_size = input.size(0);
+    const auto in_feat = input.size(1);
+    auto input_buf = torch::empty_like(input);
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda", ([&] {
+        moe_cuda_local_scatter_impl<scalar_t>(
+                input.data_ptr<scalar_t>(),
+                pos.data_ptr<int>(),
+                input_buf.data_ptr<scalar_t>(),
+                batch_size,
+                in_feat);
+    }));
+    return {input_buf,};
+}
+
+std::vector<torch::Tensor> moe_cuda_local_gather(
+        torch::Tensor output_buf,
+        torch::Tensor pos) {
+    const auto batch_size = output_buf.size(0);
+    const auto out_feat = output_buf.size(1);
+    auto output = torch::empty_like(output_buf);
+    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda", ([&] {
+        moe_cuda_local_gather_impl<scalar_t>(
+                output_buf.data_ptr<scalar_t>(),
+                pos.data_ptr<int>(),
+                output.data_ptr<scalar_t>(),
+                batch_size,
+                out_feat);
+    }));
+    return {output,};
+}
+
 std::vector<torch::Tensor> moe_cuda_forward(
-        torch::Tensor input,
-        torch::Tensor gate,
-        torch::Tensor weight
+        torch::Tensor input_buf,
+        torch::Tensor weight,
+        torch::Tensor expert_count
         ) {
-    const auto batch_size = input.size(0);
+    const auto batch_size = input_buf.size(0);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);
 
 #ifdef MOE_DEBUG
-    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
-            batch_size, num_expert, in_feat, out_feat);
+    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
+            num_expert, in_feat, out_feat);
 #endif
+    /*
     const int device = device_of(input).value().index();
     if (smgr.streams == NULL) {
         smgr.setup(num_expert, device);
     }
     auto output = input.new_zeros({batch_size, out_feat});
+    */
+    auto out_options = torch::TensorOptions()
+        .device(input_buf.device())
+        .dtype(input_buf.dtype());
+    auto output = torch::empty({batch_size, out_feat}, out_options);
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda", ([&] {
         moe_cuda_forward_impl<scalar_t>(
-            input.data_ptr<scalar_t>(),
-            gate.data_ptr<int>(),
+            input_buf.data_ptr<scalar_t>(),
             weight.data_ptr<scalar_t>(),
+            expert_count.data_ptr<int>(),
             output.data_ptr<scalar_t>(),
-            batch_size,
             in_feat,
             out_feat,
             num_expert,
             CUBLAS_OP_T
         );
     }));
     return {output, };
```
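After the scatter step, each expert owns a contiguous slice of `input_buf`, so the forward pass reduces to one GEMM per expert over its slice (issued with `CUBLAS_OP_T`, i.e. against the transposed `[out_feat x in_feat]` weight, and advanced by `ptr += expert_count[i]` as the hunk above shows). A reference sketch of the loop's semantics in torch (illustrative, not from the commit):

```python
import torch

def moe_forward_ref(input_buf, weight, expert_count):
    # weight: [num_expert, out_feat, in_feat]; rows [ptr, ptr+n) belong to expert i.
    output_buf = torch.empty(input_buf.size(0), weight.size(1),
                             dtype=input_buf.dtype, device=input_buf.device)
    ptr = 0
    for i, n in enumerate(expert_count.tolist()):
        # One GEMM per expert over its contiguous slice; .t() mirrors CUBLAS_OP_T.
        output_buf[ptr:ptr + n] = input_buf[ptr:ptr + n] @ weight[i].t()
        ptr += n  # matches `ptr += expert_count[i];` in the kernel file
    return output_buf
```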