OpenDAS / FastMoE · Commits

Commit 3a41edb8 (unverified)
Authored by Rick Ho on Sep 11, 2023; committed by GitHub on Sep 11, 2023
Parents: c1c19f3e, 1f82fb16

Merge pull request #172 from laekov/smgr_bug

[BUG FIX] Fix bugs in stream manager.
Showing 11 changed files with 74 additions and 46 deletions (+74 -46)
cuda/balancing.cu                   +6  -2
cuda/balancing.cuh                  +2  -4
cuda/fastermoe/smart_schedule.cpp   +1  -2
cuda/fastermoe/smart_schedule.h     +39 -23
cuda/global_exchange.cpp            +2  -3
cuda/global_exchange.h              +4  -6
cuda/local_exchange.cuh             +2  -4
cuda/parallel_linear.cuh            +2  -0
cuda/stream_manager.cpp             +13 -1
cuda/stream_manager.h               +2  -0
fmoe/fastermoe/schedule.py          +1  -1
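Every per-file change below follows from two CudaStreamManager members this commit adds, torchStream() and syncTorch() (their definitions appear in cuda/stream_manager.cpp further down): kernels and expert-count exchanges move onto PyTorch's current CUDA stream, the pipelined send/recv traffic in smart_schedule.h moves onto smgr->stream(num_expert), and explicit synchronization is kept mainly where the host must read device results. The sketch below shows how that torch stream is obtained and used; it assumes a PyTorch CUDA extension build, and the kernel and wrapper function are hypothetical, with only the c10 accessor taken from the diff.

#include <cuda_runtime.h>
#include <c10/cuda/CUDAStream.h>

// Hypothetical kernel, used only to show a launch on the torch stream.
__global__ void scale_kernel(float* x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= s;
}

void scale_on_torch_stream(float* d_x, float s, int n) {
    // Same call the commit wraps as CudaStreamManager::torchStream().
    cudaStream_t torch_stream = c10::cuda::getCurrentCUDAStream().stream();

    // Launched here, the kernel is ordered after the PyTorch ops that
    // produced d_x and before any later op on the same stream, so no
    // extra device-side synchronization is needed.
    scale_kernel<<<(n + 255) / 256, 256, 0, torch_stream>>>(d_x, s, n);

    // Equivalent of smgr->syncTorch(): only needed if the host is about
    // to read the result (for example before a _d2h copy).
    cudaStreamSynchronize(torch_stream);
}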
cuda/balancing.cu

@@ -104,6 +104,7 @@ std::vector<torch::Tensor> _swipe_once(
         }
     }
     long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
     fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
+    smgr->syncTorch();
     long *gec = _d2h(d_gec, n_worker);

     /* Limit number of incoming samples */
@@ -123,17 +124,20 @@ std::vector<torch::Tensor> _swipe_once(
     /* Send limit information back */
     _h2d(gec, d_gec, n_worker);
     fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
+    smgr->syncTorch();
     _d2h(d_lec, lec, n_worker);

     auto d_dropcount = _h2d(drop_count, n_worker);
     ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
-            smgr->ncclcomm, smgr->stream());
+            smgr->ncclcomm, smgr->torchStream());
+    smgr->syncTorch();
     _d2h(d_dropcount, drop_count, n_worker);

     auto d_gcap = _cudamalloc<long>(n_worker);
     _h2d(&cap, d_gcap + rank, 1);
     ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
-            smgr->ncclcomm, smgr->stream());
+            smgr->ncclcomm, smgr->torchStream());
+    smgr->syncTorch();
     auto gcap = _d2h(d_gcap, n_worker);

     /* Re-assign and update counters */
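The _swipe_once changes above run the expert-count exchange and the drop-count/capacity collectives on smgr->torchStream() and insert smgr->syncTorch() before each _d2h copy, so the host never reads a buffer the device is still writing. A standalone CUDA sketch of that ordering rule, with a plain kernel standing in for the NCCL collectives (all names illustrative):

#include <cuda_runtime.h>
#include <cstdio>

// Stand-in for an asynchronous collective that fills a device buffer.
__global__ void fill_counts(long* buf, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) buf[i] = i + 1;
}

int main() {
    const int n = 8;
    long h_counts[n];
    long* d_counts;
    cudaStream_t stream;
    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
    cudaMalloc(&d_counts, n * sizeof(long));

    // Asynchronous work enqueued on the stream (analogue of the
    // expert-count exchange launched on smgr->torchStream()).
    fill_counts<<<1, n, 0, stream>>>(d_counts, n);

    // Analogue of smgr->syncTorch(): wait for the producing stream
    // before the host-side copy (_d2h) reads the result.
    cudaStreamSynchronize(stream);
    cudaMemcpy(h_counts, d_counts, n * sizeof(long), cudaMemcpyDeviceToHost);

    for (int i = 0; i < n; ++i) printf("%ld ", h_counts[i]);
    printf("\n");

    cudaFree(d_counts);
    cudaStreamDestroy(stream);
    return 0;
}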
cuda/balancing.cuh

@@ -25,9 +25,8 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
         CudaStreamManager* smgr) {
     dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
     dim3 block_dim(1024);
-    limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
+    limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
             ec, cap, eca, n_expert, n_worker);
-    smgr->sync(1);
 }

 __global__
@@ -51,8 +50,7 @@ void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
         CudaStreamManager* smgr) {
     dim3 grid_dim(CEIL(batch_size, 1024));
     dim3 block_dim(1024);
-    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
+    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
             gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker
     );
-    smgr->sync(1);
 }
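Here the limit_by_capacity and prune_gate_by_capacity kernels are launched on smgr->torchStream() instead of smgr->stream(0), and the trailing smgr->sync(1) is dropped: work enqueued on a single stream executes in issue order, so later operations on the same stream see the kernel's output without any host-side synchronization. A small standalone sketch of that same-stream ordering (kernels are illustrative, not FastMoE's):

#include <cuda_runtime.h>
#include <cstdio>

__global__ void produce(int* buf, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) buf[i] = 2 * i;
}

// Consumes what produce() wrote; correct because both kernels share a stream.
__global__ void consume(const int* in, int* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i] + 1;
}

int main() {
    const int n = 256;
    int *d_a, *d_b, h_b[n];
    cudaStream_t stream;
    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
    cudaMalloc(&d_a, n * sizeof(int));
    cudaMalloc(&d_b, n * sizeof(int));

    // Two launches on the same stream: the second waits for the first by
    // stream ordering alone; no equivalent of smgr->sync(1) is needed here.
    produce<<<1, n, 0, stream>>>(d_a, n);
    consume<<<1, n, 0, stream>>>(d_a, d_b, n);

    // One synchronization only where the host finally needs the data.
    cudaMemcpyAsync(h_b, d_b, n * sizeof(int), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    printf("h_b[0]=%d h_b[%d]=%d\n", h_b[0], n - 1, h_b[n - 1]);

    cudaFree(d_a);
    cudaFree(d_b);
    cudaStreamDestroy(stream);
    return 0;
}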
cuda/fastermoe/smart_schedule.cpp

@@ -44,10 +44,9 @@ void _reduce_grad(
         long expert_size) {
     auto smgr = getCudaStreamManager(t.device().index());
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();

     cudaEvent_t evt_stash;
     cudaEventCreate(&evt_stash);
-    cudaEventRecord(evt_stash, torch_stream);
+    cudaEventRecord(evt_stash, smgr->torchStream());
     FMOE_SWE(smgr->stream(0), evt_stash);
     cudaEventDestroy(evt_stash);
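_reduce_grad no longer caches the current stream in a local torch_stream variable; it records the event directly on smgr->torchStream() and FMOE_SWE makes smgr->stream(0) wait for it. FMOE_SWE appears to be a thin wrapper over cudaStreamWaitEvent; the sketch below shows that record-then-wait handshake between two streams under that assumption (kernel and names illustrative):

#include <cuda_runtime.h>

// Assumed equivalent of FMOE_SWE: make `stream` wait until `event` fires.
#define WAIT_EVENT(stream, event) cudaStreamWaitEvent((stream), (event), 0)

__global__ void touch(float* x) { x[threadIdx.x] += 1.0f; }

int main() {
    float* d_x;
    cudaMalloc(&d_x, 32 * sizeof(float));
    cudaMemset(d_x, 0, 32 * sizeof(float));

    cudaStream_t torch_like, worker;   // stand-ins for torchStream() / stream(0)
    cudaStreamCreateWithFlags(&torch_like, cudaStreamNonBlocking);
    cudaStreamCreateWithFlags(&worker, cudaStreamNonBlocking);

    // Producer work on the "torch" stream.
    touch<<<1, 32, 0, torch_like>>>(d_x);

    // Record an event where the producer finished...
    cudaEvent_t evt;
    cudaEventCreate(&evt);
    cudaEventRecord(evt, torch_like);

    // ...and make the worker stream wait for it before its own kernel,
    // without blocking the CPU (the FMOE_SWE pattern).
    WAIT_EVENT(worker, evt);
    touch<<<1, 32, 0, worker>>>(d_x);

    cudaEventDestroy(evt);            // safe: the wait is already enqueued
    cudaStreamSynchronize(worker);

    cudaFree(d_x);
    cudaStreamDestroy(torch_like);
    cudaStreamDestroy(worker);
    return 0;
}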
cuda/fastermoe/smart_schedule.h

@@ -122,7 +122,7 @@ void fmoe_cuda_fused_forward_impl(
         long d_model,
         long num_expert, long rank, long world_size, long expert_size,
         long pipeline_gran, CudaStreamManager* smgr) {
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
+    smgr->syncTorch();

     int *local_ptr = new int[num_expert * world_size + 1];
     int *global_ptr = new int[num_expert * world_size + 1];
@@ -139,9 +139,11 @@ void fmoe_cuda_fused_forward_impl(
     cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
     cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
+    cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
     for (long i = 0; i < n_groups; ++i) {
         cudaEventCreate(input_ready + i);
         cudaEventCreate(output_ready + i);
+        cudaEventCreate(output_torch_ready + i);
     }

     // S_0 ... S_n
@@ -157,11 +159,11 @@ void fmoe_cuda_fused_forward_impl(
                     local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                     global_input_buf + global_ptr[gidx_recv] * d_model,
                     global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
-        cudaEventRecord(input_ready[step], smgr->stream(0));
+        cudaEventRecord(input_ready[step], smgr->stream(num_expert));
     }

     // Broadcast shadowed experts
@@ -173,22 +175,23 @@ void fmoe_cuda_fused_forward_impl(
         if (stored_models[i]) {
             if (i / num_expert == rank) {
                 cudaEventCreate(&evt_get);
-                cudaEventRecord(evt_get, torch_stream);
-                FMOE_SWE(smgr->stream(0), evt_get);
+                cudaEventRecord(evt_get, smgr->stream(0));
+                FMOE_SWE(smgr->stream(num_expert), evt_get);
                 cudaEventDestroy(evt_get);
             }
             NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
                     expert_size * sizeof(scalar_t), ncclChar,
-                    i / num_expert, smgr->ncclcomm, smgr->stream(0)));
+                    i / num_expert, smgr->ncclcomm, smgr->stream(num_expert)));
             cudaEventCreate(evt_shadow + si);
-            cudaEventRecord(evt_shadow[si], smgr->stream(0));
+            cudaEventRecord(evt_shadow[si], smgr->stream(num_expert));
             ++si;
         }
     }

     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(torch_stream, input_ready[step]);
+        FMOE_SWE(smgr->stream(0), input_ready[step]);
+        FMOE_SWE(smgr->torchStream(), input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -198,13 +201,15 @@ void fmoe_cuda_fused_forward_impl(
                 global_input_buf, global_output_buf,
                 (long)ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
-        cudaEventRecord(output_ready[step], torch_stream);
+        cudaEventRecord(output_ready[step], smgr->stream(0));
+        cudaEventRecord(output_torch_ready[step], smgr->torchStream());
     }

     // Compute over shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
-            FMOE_SWE(torch_stream, evt_shadow[si]);
+            FMOE_SWE(smgr->stream(0), evt_shadow[si]);
+            FMOE_SWE(smgr->torchStream(), evt_shadow[si]);
             stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
             long offset = local_ptr[i];
             long micro_batch_size = local_expert_count[i];
@@ -218,7 +223,8 @@ void fmoe_cuda_fused_forward_impl(
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(0), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -230,12 +236,12 @@ void fmoe_cuda_fused_forward_impl(
                     global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                     output_buf + local_ptr[idx_recv] * d_model,
                     local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
     }

-    smgr->sync(1);
+    smgr->sync(num_expert + 1);
     delete [] local_ptr;
     delete [] global_ptr;
@@ -244,12 +250,14 @@ void fmoe_cuda_fused_forward_impl(
     for (long i = 0; i < n_groups; ++i) {
         cudaEventDestroy(input_ready[i]);
         cudaEventDestroy(output_ready[i]);
+        cudaEventDestroy(output_torch_ready[i]);
     }
     for (unsigned i = 0; i < params.size(); ++i) {
         cudaEventDestroy(evt_shadow[i]);
     }

     delete [] input_ready;
     delete [] output_ready;
+    delete [] output_torch_ready;
 }
@@ -273,7 +281,7 @@ void fmoe_cuda_fused_backward_impl(
         long d_model,
         long num_expert, long rank, long world_size,
         long pipeline_gran, CudaStreamManager* smgr) {
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
+    smgr->syncTorch();

     int *local_ptr = new int[num_expert * world_size + 1];
     int *global_ptr = new int[num_expert * world_size + 1];
@@ -290,9 +298,11 @@ void fmoe_cuda_fused_backward_impl(
     cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
     cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
+    cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
     for (long i = 0; i < n_groups; ++i) {
         cudaEventCreate(input_ready + i);
         cudaEventCreate(output_ready + i);
+        cudaEventCreate(output_torch_ready + i);
     }

     // S_0 ... S_n
@@ -308,11 +318,11 @@ void fmoe_cuda_fused_backward_impl(
                     local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                     global_grad_out + global_ptr[gidx_recv] * d_model,
                     global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
-        cudaEventRecord(input_ready[step], smgr->stream(0));
+        cudaEventRecord(input_ready[step], smgr->stream(num_expert));
     }

     // Shadowed experts backward and reduce
@@ -328,7 +338,7 @@ void fmoe_cuda_fused_backward_impl(
             collect_fn(si, i / num_expert, 0);
             if (i / num_expert == rank) {
                 cudaEventCreate(evt_reduce + i % num_expert);
-                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
+                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(num_expert));
             }
             ++si;
         }
@@ -337,7 +347,8 @@ void fmoe_cuda_fused_backward_impl(
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(torch_stream, input_ready[step]);
+        FMOE_SWE(smgr->stream(0), input_ready[step]);
+        FMOE_SWE(smgr->torchStream(), input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -348,14 +359,16 @@ void fmoe_cuda_fused_backward_impl(
                 global_grad_out, global_grad_in,
                 (long)ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
-        cudaEventRecord(output_ready[step], torch_stream);
+        cudaEventRecord(output_ready[step], smgr->stream(0));
+        cudaEventRecord(output_torch_ready[step], smgr->torchStream());
     }

     // Collect gradients for shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
             if (i / num_expert == rank) {
-                FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
+                FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
+                FMOE_SWE(smgr->torchStream(), evt_reduce[i % num_expert]);
                 set_grad_fn(si, i % num_expert);
             }
             ++si;
@@ -364,7 +377,8 @@ void fmoe_cuda_fused_backward_impl(
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(0), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -376,13 +390,13 @@ void fmoe_cuda_fused_backward_impl(
                     global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                     grad_in + local_ptr[idx_recv] * d_model,
                     local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
     }

-    smgr->sync(1);
+    smgr->sync(num_expert + 1);
     checkCudaErrors(cudaGetLastError());

     delete [] local_ptr;
@@ -392,9 +406,11 @@ void fmoe_cuda_fused_backward_impl(
     for (long i = 0; i < n_groups; ++i) {
         cudaEventDestroy(input_ready[i]);
         cudaEventDestroy(output_ready[i]);
+        cudaEventDestroy(output_torch_ready[i]);
     }
     delete [] input_ready;
     delete [] output_ready;
+    delete [] output_torch_ready;
     for (long i = 0; i < num_expert; ++i) {
         if (stored_models[i + rank * num_expert]) {
             cudaEventDestroy(evt_reduce[i]);
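smart_schedule.h carries the bulk of the fix: pipelined NCCL send/recv and broadcast traffic moves from smgr->stream(0) to smgr->stream(num_expert), each step now records an output_torch_ready event on the torch stream next to the existing output_ready event, and the final drain becomes smgr->sync(num_expert + 1) so every stream involved is flushed. Below is a heavily reduced, standalone sketch of this kind of two-stream, event-pipelined schedule: one stream plays the communication role, another the expert-computation role, and per-step events carry the dependencies. All names are illustrative; the real code additionally coordinates NCCL, shadowed experts, and the torch stream.

#include <cuda_runtime.h>
#include <cstdio>

__global__ void recv_step(float* buf, int step) {           // stand-in for an NCCL recv
    buf[threadIdx.x] = (float)step;
}
__global__ void expert_step(const float* in, float* out) {   // stand-in for the expert MLP
    out[threadIdx.x] = in[threadIdx.x] * 2.0f;
}

int main() {
    const int n_groups = 4, width = 64;
    float *d_in, *d_out;
    cudaMalloc(&d_in, n_groups * width * sizeof(float));
    cudaMalloc(&d_out, n_groups * width * sizeof(float));

    cudaStream_t comm, compute;        // analogues of stream(num_expert) and stream(0)
    cudaStreamCreateWithFlags(&comm, cudaStreamNonBlocking);
    cudaStreamCreateWithFlags(&compute, cudaStreamNonBlocking);

    cudaEvent_t input_ready[n_groups], output_ready[n_groups];
    for (int i = 0; i < n_groups; ++i) {
        cudaEventCreate(&input_ready[i]);
        cudaEventCreate(&output_ready[i]);
    }

    // S_0 ... S_n: "receive" each micro-batch on the comm stream.
    for (int step = 0; step < n_groups; ++step) {
        recv_step<<<1, width, 0, comm>>>(d_in + step * width, step);
        cudaEventRecord(input_ready[step], comm);
    }

    // C_0 ... C_n: compute each micro-batch once its input has arrived.
    for (int step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(compute, input_ready[step], 0);
        expert_step<<<1, width, 0, compute>>>(d_in + step * width, d_out + step * width);
        cudaEventRecord(output_ready[step], compute);
    }

    // R_0 ... R_n: results would be sent back on the comm stream as they finish.
    for (int step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(comm, output_ready[step], 0);
        // an ncclSend(...) would be enqueued here on the comm stream
    }

    // Drain both streams, analogous to smgr->sync(num_expert + 1).
    cudaStreamSynchronize(comm);
    cudaStreamSynchronize(compute);

    float h = 0.0f;
    cudaMemcpy(&h, d_out + (n_groups - 1) * width, sizeof(float), cudaMemcpyDeviceToHost);
    printf("last step output sample: %f\n", h);

    for (int i = 0; i < n_groups; ++i) {
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
    }
    cudaFree(d_in); cudaFree(d_out);
    cudaStreamDestroy(comm); cudaStreamDestroy(compute);
    return 0;
}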
cuda/global_exchange.cpp

@@ -19,17 +19,16 @@ void fmoe_cuda_expert_exchange_impl(
                     ncclInt64,
                     i,
                     smgr->ncclcomm,
-                    smgr->stream(0)));
+                    smgr->torchStream()));
         NCCL_SAFE_CALL(ncclRecv(
                     global_expert_count + n_expert * i,
                     n_expert,
                     ncclInt64,
                     i,
                     smgr->ncclcomm,
-                    smgr->stream(0)));
+                    smgr->torchStream()));
     }
     NCCL_SAFE_CALL(ncclGroupEnd());
-    smgr->sync(1);
 }

 torch::Tensor _expert_exchange(
cuda/global_exchange.h

@@ -36,7 +36,7 @@ void fmoe_cuda_global_scatter_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
             }
             if (global_expert_count[idx]) {
                 NCCL_SAFE_CALL(ncclRecv(
@@ -45,14 +45,13 @@ void fmoe_cuda_global_scatter_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
                 recv_ptr += global_expert_count[idx];
             }
         }
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
     delete [] expert_ptr;
-    smgr->sync(1);
 }

 template<typename scalar_t>
@@ -82,7 +81,7 @@ void fmoe_cuda_global_gather_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
                 send_ptr += global_expert_count[idx];
             }
             if (local_expert_count[idx]) {
@@ -92,13 +91,12 @@ void fmoe_cuda_global_gather_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
             }
         }
         NCCL_SAFE_CALL(ncclGroupEnd());
     }

     delete [] expert_ptr;
-    smgr->sync(1);
 }
cuda/local_exchange.cuh

@@ -21,9 +21,8 @@ void fmoe_cuda_assign_pos_impl(
         CudaStreamManager* smgr) {
     size_t numel = batch_size * topk;
     assign_pos_kernel
-        <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
+        <<<CEIL(numel, 256), 256, 0, smgr->torchStream()>>>
         (cum_count, gate, pos, numel, topk);
-    smgr->sync(1);
 }

 #define PERTHREAD_EXPERTS 256
@@ -74,7 +73,6 @@ void fmoe_cuda_expert_count_impl(
         const size_t batch_size, const size_t n_expert,
         CudaStreamManager* smgr) {
     expert_count_kernel
-        <<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->stream(0)>>>
+        <<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->torchStream()>>>
         (gate_idx, expert_count, batch_size, n_expert);
-    smgr->sync(1);
 }
cuda/parallel_linear.cuh

@@ -65,6 +65,7 @@ void fmoe_cuda_linear_forward_impl(
         CudaStreamManager* smgr) {
     scalar_t alpha = 1, beta = has_bias ? 1 : 0;
+    smgr->syncTorch();

     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
             continue;
@@ -102,6 +103,7 @@ void fmoe_cuda_linear_backward_impl(
         const size_t out_feat,
         const size_t num_expert,
         CudaStreamManager* smgr) {
+    smgr->syncTorch();
     scalar_t alpha = 1, beta = 0;

     // bias
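The linear layers keep running their GEMMs on the manager's own cuBLAS handles, each of which is bound to one of the manager's streams in CudaStreamManager::setup via cublasSetStream(handles[i], streams[i]); the new smgr->syncTorch() at entry makes sure the activations produced on the torch stream are complete before those handles consume them. A minimal standalone sketch of binding a cuBLAS handle to a non-default stream and running a GEMM on it (sizes and data are illustrative, not FastMoE's expert shapes):

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cstdio>

int main() {
    const int n = 4;                       // small square GEMM for illustration
    float *d_a, *d_b, *d_c, h_c[n * n];
    cudaMalloc(&d_a, n * n * sizeof(float));
    cudaMalloc(&d_b, n * n * sizeof(float));
    cudaMalloc(&d_c, n * n * sizeof(float));

    float h_eye[n * n] = {0};
    for (int i = 0; i < n; ++i) h_eye[i * n + i] = 2.0f;   // 2 * identity
    cudaMemcpy(d_a, h_eye, sizeof(h_eye), cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, h_eye, sizeof(h_eye), cudaMemcpyHostToDevice);

    cudaStream_t stream;
    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

    cublasHandle_t handle;
    cublasCreate(&handle);
    cublasSetStream(handle, stream);       // same pairing CudaStreamManager::setup performs

    // If the inputs had been produced on another stream (e.g. the torch
    // stream), that stream would have to be synchronized or waited on first,
    // which is what the new smgr->syncTorch() call provides in the diff.
    float alpha = 1.0f, beta = 0.0f;
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n,
                &alpha, d_a, n, d_b, n, &beta, d_c, n);

    cudaStreamSynchronize(stream);
    cudaMemcpy(h_c, d_c, sizeof(h_c), cudaMemcpyDeviceToHost);
    printf("C[0][0] = %f (expect 4.0)\n", h_c[0]);

    cublasDestroy(handle);
    cudaStreamDestroy(stream);
    cudaFree(d_a); cudaFree(d_b); cudaFree(d_c);
    return 0;
}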
cuda/stream_manager.cpp

@@ -19,6 +19,10 @@ cudaStream_t CudaStreamManager::stream(size_t idx) {
     return this->streams[idx % SMGR_N_STREAMS];
 }

+cudaStream_t CudaStreamManager::torchStream() {
+    return c10::cuda::getCurrentCUDAStream().stream();
+}
+
 cublasHandle_t CudaStreamManager::handle(size_t idx) {
     if (this->use_default) {
         return at::cuda::getCurrentCUDABlasHandle();
@@ -27,6 +31,10 @@ cublasHandle_t CudaStreamManager::handle(size_t idx) {
 }

+void CudaStreamManager::syncTorch() {
+    cudaStreamSynchronize(this->torchStream());
+}
+
 void CudaStreamManager::sync(int idx) {
     if (this->use_default) {
         return;
@@ -45,7 +53,11 @@ void CudaStreamManager::setup(const int device) {
     streams = new cudaStream_t[SMGR_N_STREAMS];
     handles = new cublasHandle_t[SMGR_N_STREAMS];
     for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
-        checkCudaErrors(cudaStreamCreate(streams + i));
+        // SHOULD NOT USE: cudaStreamCreate(...)
+        // more details in
+        // https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
+        checkCudaErrors(cudaStreamCreateWithFlags(streams + i, cudaStreamNonBlocking));
         checkCudaErrors(cublasCreate(handles + i));
         cublasSetStream(handles[i], streams[i]);
     }
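Besides defining torchStream() and syncTorch(), stream_manager.cpp switches internal stream creation to cudaStreamCreateWithFlags(..., cudaStreamNonBlocking), with a comment pointing at CUDA's stream-synchronization rules: streams created with default flags synchronize implicitly with the legacy default stream, which can silently serialize the manager's side streams against work issued there. A standalone sketch of the two creation modes (kernel and sizes illustrative):

#include <cuda_runtime.h>

__global__ void busy(float* x) {
    // Small spin so the launches below are long enough to overlap.
    float v = x[threadIdx.x];
    for (int i = 0; i < (1 << 18); ++i) v = v * 1.000001f + 1e-7f;
    x[threadIdx.x] = v;
}

int main() {
    float* d_buf;
    cudaMalloc(&d_buf, 3 * 64 * sizeof(float));
    cudaMemset(d_buf, 0, 3 * 64 * sizeof(float));

    cudaStream_t blocking, non_blocking;
    // Default-flag stream: synchronizes implicitly with the legacy default
    // stream, the behaviour the commit's comment warns against.
    cudaStreamCreate(&blocking);
    // Non-blocking stream: independent of the legacy default stream; this is
    // the flag CudaStreamManager::setup now passes.
    cudaStreamCreateWithFlags(&non_blocking, cudaStreamNonBlocking);

    // Work on the legacy default stream...
    busy<<<1, 64>>>(d_buf);
    // ...can overlap with this launch on the non-blocking stream...
    busy<<<1, 64, 0, non_blocking>>>(d_buf + 64);
    // ...while this launch on the default-flag stream is serialized behind
    // the default-stream kernel above.
    busy<<<1, 64, 0, blocking>>>(d_buf + 128);

    cudaDeviceSynchronize();
    cudaStreamDestroy(blocking);
    cudaStreamDestroy(non_blocking);
    cudaFree(d_buf);
    return 0;
}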
cuda/stream_manager.h

@@ -34,8 +34,10 @@ public:
     void setup(int);
     void sync(int=0);
+    void syncTorch();
     void destroy();

+    cudaStream_t torchStream();
     cudaStream_t stream(size_t=0);
     cublasHandle_t handle(size_t=0);
fmoe/fastermoe/schedule.py

@@ -37,7 +37,7 @@ class MoEForward(Function):
         try:
             # To skip torch autograd's version check.
             with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
-                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
+                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
         except Exception as e:
             # Ignore the error and fall back for compatibility to older
             # versions of PyTorch