OpenDAS / FastMoE

Commit c8633740, authored Aug 25, 2023 by zms1999
Parent: c1c19f3e

[BUG FIX] make smart scheduling great again, fix bugs in streams management
Showing 2 changed files with 26 additions and 18 deletions:

  cuda/fastermoe/smart_schedule.h  +21 -17
  cuda/stream_manager.cpp           +5  -1
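Taken together, the smart_schedule.h changes move every NCCL call in fmoe_cuda_fused_forward_impl and fmoe_cuda_fused_backward_impl from smgr->stream(0) onto a dedicated communication stream, smgr->stream(num_expert), re-record the readiness events (input_ready, output_ready, evt_shadow, evt_reduce) on the stream that actually produces the work, and add waits on the PyTorch stream before dependent compute. The hunks all rely on CUDA's record/wait idiom; here is a minimal sketch of that idiom, assuming FMOE_SWE(stream, event) wraps cudaStreamWaitEvent (the name suggests "stream wait event", but the macro's definition is not part of this diff) and with all stream/event names illustrative:

// Minimal sketch of the cross-stream ordering pattern used throughout this
// commit. Assumes FMOE_SWE(stream, event) wraps cudaStreamWaitEvent; the
// stream/event names are illustrative, not FastMoE's real ones.
#include <cuda_runtime.h>

int main() {
    cudaStream_t comm_stream, comp_stream;
    cudaStreamCreateWithFlags(&comm_stream, cudaStreamNonBlocking);
    cudaStreamCreateWithFlags(&comp_stream, cudaStreamNonBlocking);

    cudaEvent_t input_ready;
    cudaEventCreate(&input_ready);

    // ... enqueue communication work (e.g. NCCL send/recv) on comm_stream ...

    // The event must be recorded on the SAME stream that performs the
    // communication; recording it on an idle stream lets it complete
    // immediately and silently breaks the ordering -- the class of bug
    // this commit fixes.
    cudaEventRecord(input_ready, comm_stream);

    // Compute enqueued after this wait cannot start until the
    // communication recorded above has finished.
    cudaStreamWaitEvent(comp_stream, input_ready, 0);
    // ... enqueue expert computation on comp_stream ...

    cudaStreamSynchronize(comp_stream);
    cudaEventDestroy(input_ready);
    cudaStreamDestroy(comm_stream);
    cudaStreamDestroy(comp_stream);
    return 0;
}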
cuda/fastermoe/smart_schedule.h
@@ -157,11 +157,11 @@ void fmoe_cuda_fused_forward_impl(
                     local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                     global_input_buf + global_ptr[gidx_recv] * d_model,
                     global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
-        cudaEventRecord(input_ready[step], smgr->stream(0));
+        cudaEventRecord(input_ready[step], smgr->stream(num_expert));
     }

     // Broadcast shadowed experts
@@ -173,21 +173,22 @@ void fmoe_cuda_fused_forward_impl(
         if (stored_models[i]) {
             if (i / num_expert == rank) {
                 cudaEventCreate(&evt_get);
-                cudaEventRecord(evt_get, torch_stream);
-                FMOE_SWE(smgr->stream(0), evt_get);
+                cudaEventRecord(evt_get, smgr->stream(0));
+                FMOE_SWE(smgr->stream(num_expert), evt_get);
                 cudaEventDestroy(evt_get);
             }
             NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
                         expert_size * sizeof(scalar_t), ncclChar,
-                        i / num_expert, smgr->ncclcomm, smgr->stream(0)));
+                        i / num_expert, smgr->ncclcomm, smgr->stream(num_expert)));
             cudaEventCreate(evt_shadow + si);
-            cudaEventRecord(evt_shadow[si], smgr->stream(0));
+            cudaEventRecord(evt_shadow[si], smgr->stream(num_expert));
             ++si;
         }
     }

     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
         FMOE_SWE(smgr->stream(0), input_ready[step]);
+        FMOE_SWE(torch_stream, input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
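A note on the ncclBcast hunk above: NCCL enqueues a collective on whatever CUDA stream is passed as its last argument, so once the surrounding sends and receives move to smgr->stream(num_expert), the shadowed-expert parameter broadcast and the evt_shadow record must move to the same stream; otherwise evt_shadow could fire before the broadcast it is supposed to guard. A self-contained sketch of the enqueue-then-record pattern (the single-process, single-GPU ncclCommInitAll setup exists only to make it runnable; variable names are mine):

#include <cuda_runtime.h>
#include <nccl.h>

int main() {
    int dev = 0;
    ncclComm_t comm;
    // Single-process, single-GPU communicator, just to make the sketch runnable.
    ncclCommInitAll(&comm, 1, &dev);

    cudaStream_t comm_stream;
    cudaStreamCreateWithFlags(&comm_stream, cudaStreamNonBlocking);

    float* params;
    size_t count = 1 << 20;
    cudaMalloc(&params, count * sizeof(float));

    cudaEvent_t params_ready;
    cudaEventCreate(&params_ready);

    // The broadcast is enqueued on comm_stream, so the readiness event is
    // recorded on comm_stream as well (mirroring evt_shadow in the diff).
    ncclBcast((void*)params, count, ncclFloat, /*root=*/0, comm, comm_stream);
    cudaEventRecord(params_ready, comm_stream);

    // A consumer stream would call cudaStreamWaitEvent(consumer, params_ready, 0)
    // before touching params.

    cudaStreamSynchronize(comm_stream);
    cudaEventDestroy(params_ready);
    cudaFree(params);
    cudaStreamDestroy(comm_stream);
    ncclCommDestroy(comm);
    return 0;
}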
@@ -198,12 +199,13 @@ void fmoe_cuda_fused_forward_impl(
                     global_input_buf, global_output_buf,
                     (long)ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
-        cudaEventRecord(output_ready[step], torch_stream);
+        cudaEventRecord(output_ready[step], smgr->stream(0));
     }

     // Compute over shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
             FMOE_SWE(smgr->stream(0), evt_shadow[si]);
+            FMOE_SWE(torch_stream, evt_shadow[si]);
             stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
             long offset = local_ptr[i];
@@ -218,7 +220,7 @@ void fmoe_cuda_fused_forward_impl(

     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(0), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -230,12 +232,12 @@ void fmoe_cuda_fused_forward_impl(
                     global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                     output_buf + local_ptr[idx_recv] * d_model,
                     local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
     }

-    smgr->sync(1);
+    smgr->sync(num_expert + 1);
     delete [] local_ptr;
     delete [] global_ptr;
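The forward pass now ends with smgr->sync(num_expert + 1) instead of smgr->sync(1). CudaStreamManager::sync is not shown in this diff; assuming sync(n) drains the manager's first n streams, sync(1) would wait only on stream 0 and could return while the communication stream at index num_expert still has NCCL work in flight. A sketch of that assumed semantics:

#include <cuda_runtime.h>

// Hypothetical stand-in for CudaStreamManager::sync, assuming sync(n)
// drains streams[0..n-1]. Under that reading, sync(1) misses the
// communication stream at index num_expert, while sync(num_expert + 1)
// covers streams 0..num_expert inclusive.
void sync_streams(cudaStream_t* streams, int n) {
    for (int i = 0; i < n; ++i) {
        cudaStreamSynchronize(streams[i]);
    }
}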
@@ -308,11 +310,11 @@ void fmoe_cuda_fused_backward_impl(
                     local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                     global_grad_out + global_ptr[gidx_recv] * d_model,
                     global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
-        cudaEventRecord(input_ready[step], smgr->stream(0));
+        cudaEventRecord(input_ready[step], smgr->stream(num_expert));
     }

     // Shadowed experts backward and reduce
@@ -328,7 +330,7 @@ void fmoe_cuda_fused_backward_impl(
             collect_fn(si, i / num_expert, 0);
             if (i / num_expert == rank) {
                 cudaEventCreate(evt_reduce + i % num_expert);
-                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
+                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(num_expert));
             }
             ++si;
         }
@@ -337,6 +339,7 @@ void fmoe_cuda_fused_backward_impl(

     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
         FMOE_SWE(smgr->stream(0), input_ready[step]);
+        FMOE_SWE(torch_stream, input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
@@ -348,13 +351,14 @@ void fmoe_cuda_fused_backward_impl(
                     global_grad_out, global_grad_in,
                     (long)ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
-        cudaEventRecord(output_ready[step], torch_stream);
+        cudaEventRecord(output_ready[step], smgr->stream(0));
     }

     // Collect gradients for shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
             if (i / num_expert == rank) {
                 FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
+                FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
                 set_grad_fn(si, i % num_expert);
             }
@@ -364,7 +368,7 @@ void fmoe_cuda_fused_backward_impl(

     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(0), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -376,13 +380,13 @@ void fmoe_cuda_fused_backward_impl(
                     global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                     grad_in + local_ptr[idx_recv] * d_model,
                     local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
     }

-    smgr->sync(1);
+    smgr->sync(num_expert + 1);
     checkCudaErrors(cudaGetLastError());
     delete [] local_ptr;
cuda/stream_manager.cpp
@@ -45,7 +45,11 @@ void CudaStreamManager::setup(const int device) {
     streams = new cudaStream_t[SMGR_N_STREAMS];
     handles = new cublasHandle_t[SMGR_N_STREAMS];
     for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
-        checkCudaErrors(cudaStreamCreate(streams + i));
+        // SHOULD NOT USE: cudaStreamCreate(...)
+        // more details in
+        // https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
+        checkCudaErrors(cudaStreamCreateWithFlags(streams + i,
+                    cudaStreamNonBlocking));
         checkCudaErrors(cublasCreate(handles + i));
         cublasSetStream(handles[i], streams[i]);
     }
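Background for the new comment: as the linked stream-sync-behavior page documents, a stream created with plain cudaStreamCreate is a blocking stream, which implicitly synchronizes with the legacy default stream; kernels issued on the legacy default stream (as a host framework may do) then serialize against all such streams. Creating the manager's streams with cudaStreamCreateWithFlags(..., cudaStreamNonBlocking) removes that implicit coupling. A small demonstration (the spin kernel and names are mine):

#include <cstdio>
#include <cuda_runtime.h>

// Busy-wait kernel so the overlap (or lack of it) is visible in a profiler.
__global__ void spin(long long cycles) {
    long long start = clock64();
    while (clock64() - start < cycles) { }
}

int main() {
    cudaStream_t blocking, non_blocking;
    cudaStreamCreate(&blocking);  // implicitly syncs with the legacy default stream
    cudaStreamCreateWithFlags(&non_blocking, cudaStreamNonBlocking);

    // Work on the legacy default stream (stream 0).
    spin<<<1, 1>>>(1LL << 24);
    // This launch cannot overlap with the default-stream kernel above...
    spin<<<1, 1, 0, blocking>>>(1LL << 24);
    // ...while this one may run concurrently with default-stream work.
    spin<<<1, 1, 0, non_blocking>>>(1LL << 24);

    cudaDeviceSynchronize();
    printf("done\n");
    return 0;
}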