Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
ff28081c
Commit
ff28081c
authored
Aug 25, 2023
by
zms1999
Browse files
[BUG FIX] wait torch stream
parent
c8633740
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
0 deletions
+12
-0
cuda/fastermoe/smart_schedule.h
cuda/fastermoe/smart_schedule.h
+12
-0
No files found.
cuda/fastermoe/smart_schedule.h
View file @
ff28081c
...
@@ -139,9 +139,11 @@ void fmoe_cuda_fused_forward_impl(
...
@@ -139,9 +139,11 @@ void fmoe_cuda_fused_forward_impl(
cudaEvent_t
*
input_ready
=
new
cudaEvent_t
[
n_groups
];
cudaEvent_t
*
input_ready
=
new
cudaEvent_t
[
n_groups
];
cudaEvent_t
*
output_ready
=
new
cudaEvent_t
[
n_groups
];
cudaEvent_t
*
output_ready
=
new
cudaEvent_t
[
n_groups
];
cudaEvent_t
*
output_torch_ready
=
new
cudaEvent_t
[
n_groups
];
for
(
long
i
=
0
;
i
<
n_groups
;
++
i
)
{
for
(
long
i
=
0
;
i
<
n_groups
;
++
i
)
{
cudaEventCreate
(
input_ready
+
i
);
cudaEventCreate
(
input_ready
+
i
);
cudaEventCreate
(
output_ready
+
i
);
cudaEventCreate
(
output_ready
+
i
);
cudaEventCreate
(
output_torch_ready
+
i
);
}
}
// S_0 ... S_n
// S_0 ... S_n
...
@@ -200,6 +202,7 @@ void fmoe_cuda_fused_forward_impl(
...
@@ -200,6 +202,7 @@ void fmoe_cuda_fused_forward_impl(
(
long
)
ei
,
step
*
num_expert
+
ei
,
offset
,
micro_batch_size
,
d_model
,
smgr
);
(
long
)
ei
,
step
*
num_expert
+
ei
,
offset
,
micro_batch_size
,
d_model
,
smgr
);
}
}
cudaEventRecord
(
output_ready
[
step
],
smgr
->
stream
(
0
));
cudaEventRecord
(
output_ready
[
step
],
smgr
->
stream
(
0
));
cudaEventRecord
(
output_torch_ready
[
step
],
torch_stream
);
}
}
// Compute over shadowed experts
// Compute over shadowed experts
...
@@ -221,6 +224,7 @@ void fmoe_cuda_fused_forward_impl(
...
@@ -221,6 +224,7 @@ void fmoe_cuda_fused_forward_impl(
// R_0 ... R_n
// R_0 ... R_n
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
FMOE_SWE
(
smgr
->
stream
(
num_expert
),
output_ready
[
step
]);
FMOE_SWE
(
smgr
->
stream
(
num_expert
),
output_ready
[
step
]);
FMOE_SWE
(
smgr
->
stream
(
num_expert
),
output_torch_ready
[
step
]);
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
GEN_BASE
(
step
);
GEN_BASE
(
step
);
NCCL_SAFE_CALL
(
ncclGroupStart
());
NCCL_SAFE_CALL
(
ncclGroupStart
());
...
@@ -246,12 +250,14 @@ void fmoe_cuda_fused_forward_impl(
...
@@ -246,12 +250,14 @@ void fmoe_cuda_fused_forward_impl(
for
(
long
i
=
0
;
i
<
n_groups
;
++
i
)
{
for
(
long
i
=
0
;
i
<
n_groups
;
++
i
)
{
cudaEventDestroy
(
input_ready
[
i
]);
cudaEventDestroy
(
input_ready
[
i
]);
cudaEventDestroy
(
output_ready
[
i
]);
cudaEventDestroy
(
output_ready
[
i
]);
cudaEventDestroy
(
output_torch_ready
[
i
]);
}
}
for
(
unsigned
i
=
0
;
i
<
params
.
size
();
++
i
)
{
for
(
unsigned
i
=
0
;
i
<
params
.
size
();
++
i
)
{
cudaEventDestroy
(
evt_shadow
[
i
]);
cudaEventDestroy
(
evt_shadow
[
i
]);
}
}
delete
[]
input_ready
;
delete
[]
input_ready
;
delete
[]
output_ready
;
delete
[]
output_ready
;
delete
[]
output_torch_ready
;
}
}
...
@@ -292,9 +298,11 @@ void fmoe_cuda_fused_backward_impl(
...
@@ -292,9 +298,11 @@ void fmoe_cuda_fused_backward_impl(
cudaEvent_t
*
input_ready
=
new
cudaEvent_t
[
n_groups
];
cudaEvent_t
*
input_ready
=
new
cudaEvent_t
[
n_groups
];
cudaEvent_t
*
output_ready
=
new
cudaEvent_t
[
n_groups
];
cudaEvent_t
*
output_ready
=
new
cudaEvent_t
[
n_groups
];
cudaEvent_t
*
output_torch_ready
=
new
cudaEvent_t
[
n_groups
];
for
(
long
i
=
0
;
i
<
n_groups
;
++
i
)
{
for
(
long
i
=
0
;
i
<
n_groups
;
++
i
)
{
cudaEventCreate
(
input_ready
+
i
);
cudaEventCreate
(
input_ready
+
i
);
cudaEventCreate
(
output_ready
+
i
);
cudaEventCreate
(
output_ready
+
i
);
cudaEventCreate
(
output_torch_ready
+
i
);
}
}
// S_0 ... S_n
// S_0 ... S_n
...
@@ -352,6 +360,7 @@ void fmoe_cuda_fused_backward_impl(
...
@@ -352,6 +360,7 @@ void fmoe_cuda_fused_backward_impl(
(
long
)
ei
,
step
*
num_expert
+
ei
,
offset
,
micro_batch_size
,
d_model
,
smgr
);
(
long
)
ei
,
step
*
num_expert
+
ei
,
offset
,
micro_batch_size
,
d_model
,
smgr
);
}
}
cudaEventRecord
(
output_ready
[
step
],
smgr
->
stream
(
0
));
cudaEventRecord
(
output_ready
[
step
],
smgr
->
stream
(
0
));
cudaEventRecord
(
output_torch_ready
[
step
],
torch_stream
);
}
}
// Collect gradients for shadowed experts
// Collect gradients for shadowed experts
...
@@ -369,6 +378,7 @@ void fmoe_cuda_fused_backward_impl(
...
@@ -369,6 +378,7 @@ void fmoe_cuda_fused_backward_impl(
// R_0 ... R_n
// R_0 ... R_n
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
FMOE_SWE
(
smgr
->
stream
(
num_expert
),
output_ready
[
step
]);
FMOE_SWE
(
smgr
->
stream
(
num_expert
),
output_ready
[
step
]);
FMOE_SWE
(
smgr
->
stream
(
num_expert
),
output_torch_ready
[
step
]);
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
GEN_BASE
(
step
);
GEN_BASE
(
step
);
NCCL_SAFE_CALL
(
ncclGroupStart
());
NCCL_SAFE_CALL
(
ncclGroupStart
());
...
@@ -396,9 +406,11 @@ void fmoe_cuda_fused_backward_impl(
...
@@ -396,9 +406,11 @@ void fmoe_cuda_fused_backward_impl(
for
(
long
i
=
0
;
i
<
n_groups
;
++
i
)
{
for
(
long
i
=
0
;
i
<
n_groups
;
++
i
)
{
cudaEventDestroy
(
input_ready
[
i
]);
cudaEventDestroy
(
input_ready
[
i
]);
cudaEventDestroy
(
output_ready
[
i
]);
cudaEventDestroy
(
output_ready
[
i
]);
cudaEventDestroy
(
output_torch_ready
[
i
]);
}
}
delete
[]
input_ready
;
delete
[]
input_ready
;
delete
[]
output_ready
;
delete
[]
output_ready
;
delete
[]
output_torch_ready
;
for
(
long
i
=
0
;
i
<
num_expert
;
++
i
)
{
for
(
long
i
=
0
;
i
<
num_expert
;
++
i
)
{
if
(
stored_models
[
i
+
rank
*
num_expert
])
{
if
(
stored_models
[
i
+
rank
*
num_expert
])
{
cudaEventDestroy
(
evt_reduce
[
i
]);
cudaEventDestroy
(
evt_reduce
[
i
]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment