OpenDAS / FastMoE · Commits

Commit 771dc62d, authored Mar 30, 2022 by Rick Ho
Parent: ad651f03

    forward code in smart schedule

Showing 8 changed files with 206 additions and 54 deletions (+206 −54)
cuda/fastermoe/smart_schedule.cpp   +27 −3
cuda/fastermoe/smart_schedule.h     +62 −36
cuda/fmoe_cuda.cpp                  +11 −3
cuda/stream_manager.cpp             +13 −0
cuda/stream_manager.h               +2 −1
fmoe/fastermoe/expert_utils.py      +56 −0
fmoe/fastermoe/schedule.py          +31 −9
fmoe/layers.py                      +4 −2
cuda/fastermoe/smart_schedule.cpp
@@ -25,8 +25,12 @@ std::vector<torch::Tensor> _smart_sch_forward(
         torch::Tensor global_expert_count,
         torch::Tensor stored_models,
         long global_batch_size,
+        long expert_size,
         long n_workers,
-        py::function forward_fn) {
+        py::function forward_fn,
+        py::function get_param_fn,
+        py::function stash_fn,
+        py::function pop_fn) {
     if (pipeline_gran == -1) {
         char* p = getenv("FMOE_FASTER_GROUP_SIZE");
         if (p) {
@@ -50,11 +54,26 @@ std::vector<torch::Tensor> _smart_sch_forward(
     auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});

+    std::vector<torch::Tensor> params;
+    auto stored_models_ = stored_models.data_ptr<bool>();
+    for (long i = 0; i < num_expert * n_workers; ++i) {
+        if (stored_models_[i]) {
+            torch::Tensor t = input_buf.new_empty({expert_size});
+            if (i / num_expert == rank) {
+                get_param_fn(t);
+            }
+            params.push_back(t);
+        }
+    }
+
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
             "fmoe_cuda_smart_sch_forward", ([&] {
         fmoe_cuda_fused_forward_impl(
-            forward_fn,
+            forward_fn, stash_fn, pop_fn,
+            input_buf.device(), params,
             input_buf.data_ptr<scalar_t>(),
             global_input_buf.data_ptr<scalar_t>(),
@@ -64,7 +83,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
             local_expert_count.data_ptr<long>(),
             global_expert_count.data_ptr<long>(),
             stored_models.data_ptr<bool>(),
-            d_model, num_expert, rank, n_workers,
+            d_model, num_expert, rank, n_workers, expert_size,
             pipeline_gran, smgr);
     }));
     return {output_buf, global_input_buf};
@@ -77,8 +96,13 @@ torch::Tensor _smart_sch_backward(
         torch::Tensor stored_models,
         long buf_batch_size,
         long global_batch_size,
+        long expert_size,
         long n_workers,
-        py::function backward_fn) {
+        py::function backward_fn,
+        py::function stash_fn,
+        py::function pop_fn,
+        py::function collect_fn,
+        py::function set_grad_fn) {
     const auto num_expert = local_expert_count.size(0) / n_workers;
     auto smgr = getCudaStreamManager(grad_out.device().index());
     int rank;
cuda/fastermoe/smart_schedule.h
@@ -39,6 +39,7 @@ void _exchange_with(
     int gidx_recv = ei * world_size + rank_recv; \
     int idx_self = ei + rank * num_expert;

 void _compute_ptrs(long num_expert, long rank, long world_size,
         const long* local_expert_count,
         const long* global_expert_count,
@@ -73,10 +74,11 @@ void _compute_ptrs(long num_expert, long rank, long world_size,
     }
 }

 template<typename scalar_t>
 void _compute_fn(py::function fn, c10::Device device,
         scalar_t* inp_buf, scalar_t* out_buf,
         int ei,
-        long step, long offset, long micro_batch_size, long d_model,
+        long idx, long offset, long micro_batch_size, long d_model,
         CudaStreamManager* smgr) {
     auto options = torch::TensorOptions()
         .dtype(c10::CppTypeToScalarType<scalar_t>::value)
@@ -87,7 +89,7 @@ void _compute_fn(py::function fn, c10::Device device,
     auto oup = torch::from_blob(out_buf + offset * d_model,
             {micro_batch_size, d_model}, options);
     smgr->use_default = true;
-    fn(inp, oup, step);
+    fn(inp, oup, idx);
     smgr->use_default = false;
 }
@@ -95,9 +97,12 @@
 template<typename scalar_t>
 void fmoe_cuda_fused_forward_impl(
         py::function forward_fn,
-        const scalar_t* input_buf,
+        py::function stash_fn,
+        py::function pop_fn,
+        c10::Device device,
+        std::vector<torch::Tensor> params,
+        scalar_t* input_buf,
         scalar_t* global_input_buf,
         scalar_t* global_output_buf,
         scalar_t* output_buf,
@@ -107,8 +112,9 @@ void fmoe_cuda_fused_forward_impl(
         const bool* stored_models,
-        long d_model, long num_expert, long rank, long world_size,
+        long d_model, long num_expert, long rank, long world_size, long expert_size,
         long pipeline_gran, CudaStreamManager* smgr) {
+    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
     int* local_ptr = new int[num_expert * world_size + 1];
     int* global_ptr = new int[num_expert * world_size + 1];
@@ -130,8 +136,9 @@ void fmoe_cuda_fused_forward_impl(
         cudaEventCreate(output_ready + i);
     }

     // S_0 ... S_n
     for (long step = 0; step < n_groups; ++step) {
-        for (int ei = 0; ei < num_expert; ++ei) {
+        for (long ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
             for (int j = 0; j < pipeline_gran; ++j) {
@@ -149,8 +156,30 @@ void fmoe_cuda_fused_forward_impl(
         cudaEventRecord(input_ready[step], smgr->stream(0));
     }

+    // Broadcast shadowed experts
+    cudaEvent_t evt_get, *evt_shadow;
+    if (params.size() > 0) {
+        evt_shadow = new cudaEvent_t[params.size()];
+    }
+    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
+        if (stored_models[i]) {
+            if (i / num_expert == rank) {
+                cudaEventCreate(&evt_get);
+                cudaEventRecord(evt_get, torch_stream);
+                cudaStreamWaitEvent(smgr->stream(1), evt_get);
+            }
+            NCCL_SAFE_CALL(ncclBcast(params[si].data_ptr<void>(),
+                    expert_size * sizeof(scalar_t), ncclChar,
+                    i / num_expert, smgr->ncclcomm, smgr->stream(0)));
+            cudaEventCreate(evt_shadow + si);
+            cudaEventRecord(evt_shadow[si], smgr->stream(0));
+            ++si;
+        }
+    }
+
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
-        cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
+        cudaStreamWaitEvent(torch_stream, input_ready[step], 0);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -159,12 +188,27 @@ void fmoe_cuda_fused_forward_impl(
             _compute_fn(forward_fn, device,
                     global_input_buf, global_output_buf,
                     ei, step, offset, micro_batch_size, d_model, smgr);
         }
-        auto stream = c10::cuda::getCurrentCUDAStream().stream();
-        cudaEventRecord(output_ready[step], stream);
+        cudaEventRecord(output_ready[step], torch_stream);
     }

+    // Compute over shadowed experts
+    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
+        if (stored_models[i]) {
+            stash_fn(params[si], si);
+            cudaStreamWaitEvent(torch_stream, evt_shadow[si], 0);
+            long offset = local_ptr[i];
+            long micro_batch_size = local_expert_count[i];
+            _compute_fn(forward_fn, device, input_buf, output_buf,
+                    n_groups + si, offset, micro_batch_size, d_model, smgr);
+            ++si;
+        }
+    }
+    pop_fn();
+
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
         cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
         for (int ei = 0; ei < num_expert; ++ei) {
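The hunk above fixes the ordering of the fused forward pass: scatter groups S_0..S_n, compute groups C_0..C_n, an extra pass over shadowed experts, then gather groups R_0..R_n. The following is a minimal Python sketch of that ordering only; the callable names are hypothetical stand-ins for the NCCL exchanges and expert launches, and the CUDA-event overlap between phases (and the parameter broadcast itself) is deliberately omitted.

# Conceptual sketch of the smart-schedule forward ordering (not the FastMoE API).
def smart_schedule_forward(n_groups, shadowed_params,
                           scatter_group, compute_group, gather_group,
                           compute_shadowed, stash_fn, pop_fn):
    # S_0 ... S_n: launch the input exchange for every micro-batch group.
    for step in range(n_groups):
        scatter_group(step)

    # (In the real code, shadowed expert parameters are broadcast here.)

    # C_0 ... C_n: run the local experts on each group as its inputs arrive.
    for step in range(n_groups):
        compute_group(step)

    # Compute over shadowed experts: swap in the received parameters,
    # run the expert locally, then restore the original parameters.
    for si, params in enumerate(shadowed_params):
        stash_fn(params, si)
        compute_shadowed(si)
    pop_fn()

    # R_0 ... R_n: send each group's outputs back to the workers that own them.
    for step in range(n_groups):
        gather_group(step)


log = []
smart_schedule_forward(
    n_groups=2, shadowed_params=["w"],
    scatter_group=lambda s: log.append(f"S_{s}"),
    compute_group=lambda s: log.append(f"C_{s}"),
    gather_group=lambda s: log.append(f"R_{s}"),
    compute_shadowed=lambda si: log.append(f"shadow_{si}"),
    stash_fn=lambda params, si: log.append(f"stash_{si}"),
    pop_fn=lambda: log.append("pop"),
)
print(log)  # ['S_0', 'S_1', 'C_0', 'C_1', 'stash_0', 'shadow_0', 'pop', 'R_0', 'R_1']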
@@ -184,31 +228,6 @@ void fmoe_cuda_fused_forward_impl(
         }
     }
-    /* TODO: Shadowing support
-    int offset = global_ptr[world_size * num_expert];
-    for (int j = 0; j < world_size; j++) {
-        for (int i = 0; i < num_expert; i++) {
-            int idx = j * num_expert + i;
-            if (!stored_models[idx])
-                continue;
-            weight1 = params[j][0][0].data_ptr<scalar_t>();
-            weight2 = params[j][0][last].data_ptr<scalar_t>();
-            auto stream = 2 + (idx % (SMGR_N_STREAMS - 2));
-            _compute_mlp_forward(
-                input_buf + local_ptr[idx] * d_model, weight1, weight2,
-                middle_buf + (offset + local_global_ptr[idx]) * d_hidden,
-                output_buf + local_ptr[idx] * d_model,
-                i,
-                0, local_expert_count[idx],
-                d_model, d_hidden,
-                smgr->stream(stream), smgr->handle(stream));
-        }
-    }*/
     delete [] local_ptr;
     delete [] global_ptr;
     delete [] local_global_ptr;
@@ -217,8 +236,14 @@ void fmoe_cuda_fused_forward_impl(
         cudaEventDestroy(input_ready[i]);
         cudaEventDestroy(output_ready[i]);
     }
+    for (unsigned i = 0; i < params.size(); ++i) {
+        cudaEventDestroy(evt_shadow[i]);
+    }
     delete [] input_ready;
     delete [] output_ready;
+    if (params.size()) {
+        delete [] evt_shadow;
+    }
 }
@@ -238,6 +263,7 @@ void fmoe_cuda_fused_backward_impl(
         long d_model, long num_expert, long rank, long world_size,
         long pipeline_gran, CudaStreamManager* smgr) {
+    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
     int* local_ptr = new int[num_expert * world_size + 1];
     int* global_ptr = new int[num_expert * world_size + 1];
@@ -289,9 +315,9 @@ void fmoe_cuda_fused_backward_impl(
             _compute_fn(backward_fn, device,
                     global_grad_out, global_grad_in,
                     ei, step, offset, micro_batch_size, d_model, smgr);
         }
         // TODO: get pytorch's compute stream
+        cudaEventRecord(output_ready[step], torch_stream);
     }

     for (long step = 0; step < n_groups; ++step) {
cuda/fmoe_cuda.cpp
@@ -63,8 +63,13 @@ std::vector<torch::Tensor> _smart_sch_forward(
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
         torch::Tensor stored_models,
-        long global_batch_size, long n_workers,
-        py::function forward_fn);
+        long global_batch_size, long expert_size, long n_workers,
+        py::function forward_fn,
+        py::function get_param_fn,
+        py::function stash_fn,
+        py::function pop_fn);

 torch::Tensor _smart_sch_backward(
         torch::Tensor grad_out,
         torch::Tensor local_expert_count,
@@ -72,8 +77,11 @@ torch::Tensor _smart_sch_backward(
         torch::Tensor stored_models,
         long buf_batch_size, long global_batch_size,
+        long expert_size,
         long n_workers,
-        py::function backward_fn);
+        py::function backward_fn,
+        py::function collect_fn,
+        py::function set_grad_fn);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifdef FMOE_USE_NCCL
cuda/stream_manager.cpp
@@ -3,21 +3,34 @@
 #include <cassert>
 #include <thread>
 #include <iostream>

+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+
 #include "fastermoe/status.h"
 #include "stream_manager.h"

 #define SMGR_N_STREAMS 16

 cudaStream_t CudaStreamManager::stream(size_t idx) {
+    if (this->use_default) {
+        return c10::cuda::getCurrentCUDAStream().stream();
+    }
     return this->streams[idx % SMGR_N_STREAMS];
 }

 cublasHandle_t CudaStreamManager::handle(size_t idx) {
+    if (this->use_default) {
+        return at::cuda::getCurrentCUDABlasHandle();
+    }
     return this->handles[idx % SMGR_N_STREAMS];
 }

 void CudaStreamManager::sync(int idx) {
+    if (this->use_default) {
+        return;
+    }
     for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
         cudaStreamSynchronize(streams[i]);
     }
cuda/stream_manager.h
@@ -21,13 +21,14 @@ public:
     int device;
     cublasHandle_t* handles;
     cudaStream_t* streams;
+    bool use_default;

 #ifdef FMOE_USE_NCCL
     char ncclgood;
     ncclComm_t ncclcomm;
 #endif

 public:
-    CudaStreamManager(int device_) : device(device_) {
+    CudaStreamManager(int device_) : device(device_), use_default(false) {
         this->setup(device);
     }
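The new use_default flag makes stream(), handle(), and sync() resolve to PyTorch's current CUDA stream and cuBLAS handle while a Python expert callback runs, so work launched from Python and work launched by the scheduler agree on ordering. The snippet below is only a PyTorch-level analogy of that idea (current stream vs. a dedicated side stream, with an explicit wait); it is a sketch, not the FastMoE stream-manager API.

import torch

if torch.cuda.is_available():
    default = torch.cuda.current_stream()   # what use_default == true resolves to
    side = torch.cuda.Stream()              # analogue of an smgr-managed stream

    x = torch.ones(1024, device="cuda")
    with torch.cuda.stream(side):           # communication-style work on the side stream
        y = x * 2
    default.wait_stream(side)               # like cudaStreamWaitEvent before the compute phase
    z = y + 1                               # default stream is now ordered after the side stream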
fmoe/fastermoe/expert_utils.py (new file, 0 → 100644)
import torch


def get_expert_param_size(e):
    return sum(map(lambda x: x.numel(), e.parameters()))


def get_expert_params(e, out):
    offset = 0
    for n, p in e.named_parameters():
        seg = out[offset:offset + p.numel()]
        offset += p.numel()
        seg.copy_(p)


def stash_expert_params(e, params):
    if not hasattr(e, 'expert_param_stash'):
        setattr(e, 'expert_param_stash', dict())
    offset = 0
    for n, p in e.named_parameters():
        if n not in e.expert_param_stash:
            e.expert_param_stash[n] = p.data.clone()
        with torch.no_grad():
            seg = params[offset:offset + p.numel()]
            offset += p.numel()
            p.copy_(seg.reshape(p.shape))


def pop_expert_params(e):
    for n, p in e.named_parameters():
        with torch.no_grad():
            p.copy_(e.expert_param_stash[n])
    e.expert_param_stash.clear()


def collect_expert_grads(e, grads):
    offset = 0
    for _, p in e.named_parameters():
        seg = grads[offset:offset + p.numel()]
        offset += p.numel()
        if p.grad is not None:
            seg.copy_(p.grad)
            p.grad = None
        else:
            seg.zero_()


def set_grads(e, grads):
    offset = 0
    for n, p in e.named_parameters():
        seg = grads[offset:offset + p.numel()]
        offset += p.numel()
        if p.grad is None:
            p.grad = seg.clone()
        else:
            p.grad += seg
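These helpers move an expert's parameters and gradients through flat buffers, which is what the new get_param_fn / stash_fn / pop_fn / collect_fn / set_grad_fn callbacks feed to the C++ scheduler. Below is a minimal usage sketch; it assumes the module is importable as fmoe.fastermoe.expert_utils, uses a toy expert with a single 1-D parameter so the buffer copies stay shape-compatible with the helpers as written above, and is not part of the commit itself.

import torch
import torch.nn as nn
from fmoe.fastermoe import expert_utils  # assumed import path for the new module


class TinyExpert(nn.Module):
    # Hypothetical stand-in for a real expert; one 1-D parameter keeps
    # the flat-buffer copies shape-compatible with the helpers above.
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(4))

    def forward(self, x):
        return x * self.scale


expert = TinyExpert()

# Flatten all parameters into one contiguous buffer
# (what get_param_fn fills before a shadowed expert is broadcast).
n = expert_utils.get_expert_param_size(expert)
flat = torch.empty(n)
expert_utils.get_expert_params(expert, flat)

# Temporarily swap in another worker's parameters, then restore the originals.
expert_utils.stash_expert_params(expert, torch.zeros(n))
assert expert.scale.abs().sum() == 0      # shadowed (zero) parameters are active
expert_utils.pop_expert_params(expert)    # original parameters are back

# After a backward pass, gradients round-trip through the same kind of buffer,
# mirroring what collect_fn / set_grad_fn do in the backward schedule.
expert(torch.randn(2, 4)).sum().backward()
grads = torch.empty(n)
expert_utils.collect_expert_grads(expert, grads)   # copies grads out and clears p.grad
expert_utils.set_grads(expert, grads)              # writes them back to p.grad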
fmoe/fastermoe/schedule.py
@@ -7,6 +7,7 @@ from torch.autograd.function import Function
 from fmoe.functions import prepare_forward, ensure_comm
 from fmoe.functions import _local_scatter, _local_gather
 import fmoe_cuda as fmoe_native
+import expert_utils


 class MoEForward(Function):
@@ -14,6 +15,7 @@ class MoEForward(Function):
     def forward(
             ctx,
             expert_fn,
+            experts,
             inp,
             # models,
             pos_s, pos_g,
             local_expert_count, global_expert_count,
@@ -25,8 +27,8 @@ class MoEForward(Function):
         # TODO: leave this for furture work of expert shadowing
         # model_params = [[tuple(m.parameters()) for m in node] for node in models]

-        ctx.gibs = [None] * world_size
-        ctx.gobs = [None] * world_size
+        ctx.gibs = [None] * (world_size * 2)
+        ctx.gobs = [None] * (world_size * 2)

         def _expert_forward(x, y, idx):
             x = x.data
             with torch.enable_grad():
@@ -36,11 +38,23 @@ class MoEForward(Function):
             ctx.gobs[idx] = y0
             y.copy_(y0)

+        ctx.experts = experts
+        if stored_models.any():
+            ctx.expert_size = expert_utils.get_expert_param_size(experts)
+        else:
+            ctx.expert_size = 0
+        get_param_fn = lambda out: expert_utils.get_expert_params(experts, out)
+        pop_fn = lambda: expert_utils.pop_expert_params(experts)
+        ctx.shadows = [None] * world_size
+        def stash_fn(params, idx):
+            expert_utils.stash_expert_params(experts, p)
+            ctx.shadows[idx] = params
+
         local_output_buf, gib = fmoe_native.smart_sch_forward(
                 local_input_buf,
                 local_expert_count, global_expert_count,
-                stored_models, fwd_batch_size, world_size,
-                _expert_forward)
+                stored_models, fwd_batch_size, ctx.expert_size, world_size,
+                _expert_forward, get_param_fn, stash_fn, pop_fn)

         out = _local_gather(local_output_buf, pos_g, out_batch_size,
                 maybe_overlap=False)
@@ -65,19 +79,27 @@ class MoEForward(Function):
             x = ctx.gibs[idx]
             grad_x.copy_(x.grad)

+        experts = ctx.experts
+        def stash_fn(idx):
+            expert_utils.stash_expert_params(experts, ctx.shadows[idx])
+        pop_fn = lambda: expert_utils.pop_expert_params(experts)
+        collect_fn = lambda g: expert_utils.collect_expert_grads(experts, g)
+        set_grad_fn = lambda g: expert_utils.set_grads(experts, g)
+
         grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
         grad_in_buf = fmoe_native.smart_sch_backward(
                 grad_out_buf,
                 local_expert_count, global_expert_count,
                 stored_models,
-                pos_s.shape[0], fwd_batch_size, world_size,
-                _expert_backward)
+                pos_s.shape[0], fwd_batch_size, ctx.expert_size, world_size,
+                _expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn)
         grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)

-        return (None, grad_in, None, None, None, None, None, None, None, None)
+        return (None, None, grad_in, None, None, None, None, None, None, None, None)


-def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size):
+def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None):
     # TODO: Using multiple tensors as input is to be supported.
     assert(isinstance(inp, torch.Tensor))
     # TODO: Support many experts on each process
@@ -98,7 +120,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size):
     topk = gate.shape[1]
     out_batch_size = inp.shape[0] * topk

-    return MoEForward.apply(expert_fn, inp,
+    return MoEForward.apply(expert_fn, experts, inp,
             torch.div(pos, topk, rounding_mode='floor'), pos,
             local_expert_count, global_expert_count,
             stored_models, fwd_batch_size, out_batch_size,
             world_size)
fmoe/layers.py
@@ -21,7 +21,7 @@ def mark_module_parallel_comm(module, comm):
         setattr(p, "dp_comm", comm)


-def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
+def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size, **kwargs):
     r"""
     A private function that performs the following steps to complete the MoE
     computation.
@@ -227,7 +227,9 @@ class FMoE(nn.Module):
             gate_top_k_idx = gate_top_k_idx[mask == 0, :]

         fwd = _fmoe_general_global_forward(
-            moe_inp, gate_top_k_idx, self.expert_fn,
-            self.num_expert, self.world_size
+            moe_inp, gate_top_k_idx, self.expert_fn,
+            self.num_expert, self.world_size,
+            experts=self.experts
         )

         # recover deleted tensors