OpenDAS / FastMoE · Commits

Commit 7e949a62, authored Dec 31, 2020 by Rick Ho

fix output bug

Parent: 4cb75d42
Changes: 1 changed file with 64 additions and 54 deletions (pytorch/cuda/moe_cuda_kernel.cu, +64 / -54).
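Judging from the hunks below, the commit does two related things. The per-expert exchange buffers input_buf and output_buf, and the receive-side local_output_buf, are no longer allocated unconditionally; they are allocated only under cm->rank > 1, and in the other branch the pointers are aliased instead (input_buf = local_input_buf, local_output_buf = output_buf), with the frees at the end guarded to match. The "output bug" itself is in the NCCL exchange of the expert results, which previously reused the input buffers; the changed lines are:

    - ncclRecv(local_input_buf + expert_ptr[recv_id] * in_feat, ...
    + ncclRecv(local_output_buf + expert_ptr[recv_id] * in_feat, ...

    - ncclSend(input_buf + send_ptr * in_feat, ...
    + ncclSend(output_buf + send_ptr * in_feat, ...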
pytorch/cuda/moe_cuda_kernel.cu (view file @ 7e949a62)
@@ -82,8 +82,6 @@ void moe_cuda_forward_impl(
     checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size *
             in_feat));
-    checkCudaErrors(cudaMalloc(&local_output_buf, sizeof(scalar_t) * batch_size *
-            out_feat));
 #ifdef MOE_BREAKDOWN
     timestamp(t_malloc);
@@ -136,12 +134,8 @@ void moe_cuda_forward_impl(
         expert_sz += expert_n[i];
     }
     scalar_t *input_buf, *hidden_buf, *output_buf;
-    checkCudaErrors(cudaMalloc(&input_buf,
-            sizeof(scalar_t) * expert_sz * in_feat));
     checkCudaErrors(cudaMalloc(&hidden_buf,
             sizeof(scalar_t) * expert_sz * hidden_feat));
-    checkCudaErrors(cudaMalloc(&output_buf,
-            sizeof(scalar_t) * expert_sz * out_feat));
 #ifdef MOE_DEBUG
     for (int i = 0; i < tot_expert; ++i) {
@@ -166,6 +160,11 @@ void moe_cuda_forward_impl(
             local_input_buf);
     h->sync(0);
+    if (cm->rank > 1) {
+        checkCudaErrors(cudaMalloc(&input_buf,
+                sizeof(scalar_t) * expert_sz * in_feat));
+        checkCudaErrors(cudaMalloc(&output_buf,
+                sizeof(scalar_t) * expert_sz * out_feat));
     ncclGroupStart();
     int recv_ptr = 0;
     for (int i = 0; i < num_expert; ++i) {
@@ -192,6 +191,9 @@ void moe_cuda_forward_impl(
         }
     }
     ncclGroupEnd();
+    } else {
+        input_buf = local_input_buf;
+    }
 #ifdef MOE_BREAKDOWN
     h->sync();
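The two hunks above move the allocation of the exchange buffers behind the cm->rank > 1 check (presumably the multi-process case, where rows must be redistributed over NCCL before the experts run) and, when no exchange is needed, make input_buf an alias of local_input_buf so the expert computation can use the same pointer in both cases. A minimal, self-contained sketch of that alias-or-allocate pattern, with made-up sizes and without the checkCudaErrors wrapper, just to make the ownership explicit:

    // Sketch only: alias-or-allocate, not the actual FastMoE code.
    #include <cuda_runtime.h>

    int main() {
        const size_t rows = 1024, feat = 512;
        float *local_buf = nullptr;                    // always allocated
        float *exchange_buf = nullptr;                 // only needed when data crosses processes
        cudaMalloc(&local_buf, rows * feat * sizeof(float));

        bool multi_process = false;                    // stands in for the cm->rank > 1 check
        if (multi_process) {
            cudaMalloc(&exchange_buf, rows * feat * sizeof(float));
            // ... an all-to-all style exchange would fill exchange_buf here ...
        } else {
            exchange_buf = local_buf;                  // no exchange: alias, do not copy
        }

        // ... kernels read exchange_buf in both cases ...

        if (multi_process) {
            cudaFree(exchange_buf);                    // free only what was separately allocated
        }
        cudaFree(local_buf);
        return 0;
    }

The cleanup has to mirror the allocation, which is what the last hunk on this page does.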
@@ -244,13 +246,16 @@ void moe_cuda_forward_impl(
             1e6);
 #endif
+    if (cm->rank > 1) {
+        checkCudaErrors(cudaMalloc(&local_output_buf, sizeof(scalar_t) * batch_size *
+                out_feat));
     ncclGroupStart();
     int send_ptr = 0;
     for (int i = 0; i < num_expert; ++i) {
         for (int j = 0; j < cm->size; ++j) {
             int recv_id = i + j * num_expert;
             if (expert_count[recv_id]) {
-                ncclRecv(local_input_buf + expert_ptr[recv_id] * in_feat,
+                ncclRecv(local_output_buf + expert_ptr[recv_id] * in_feat,
                     expert_count[recv_id] * in_feat * sizeof(scalar_t),
                     ncclChar,
                     j,
@@ -259,7 +264,7 @@ void moe_cuda_forward_impl(
             }
             int send_id = i * cm->size + j;
             if (all_expert_count[send_id]) {
-                ncclSend(input_buf + send_ptr * in_feat,
+                ncclSend(output_buf + send_ptr * in_feat,
                     all_expert_count[send_id] * in_feat * sizeof(scalar_t),
                     ncclChar,
                     j,
@@ -270,6 +275,9 @@ void moe_cuda_forward_impl(
             }
         }
     ncclGroupEnd();
+    } else {
+        local_output_buf = output_buf;
+    }
     batch_gather_kernel<scalar_t>
             <<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos,
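These three hunks are the output bug of the commit title. Judging from the removed lines, the gather phase was receiving into local_input_buf and sending from input_buf, so the expert results sitting in output_buf were never exchanged; the fix points ncclRecv at local_output_buf and ncclSend at output_buf, and aliases local_output_buf = output_buf when there is nothing to exchange. Pieced together, the corrected loop looks roughly like the sketch below. The communicator and stream arguments are cut off in the diff, and the advance of send_ptr is not visible either, so those parts are assumptions, not quotes from the file; the byte counts are written with in_feat because that is how the context lines read in this commit.

    // Sketch assembled from the hunks above; parameters stand in for values the
    // real moe_cuda_forward_impl already has in scope. comm, stream and the
    // send_ptr update are assumptions (they are not visible on this page).
    #include <cuda_runtime.h>
    #include <nccl.h>

    template <typename scalar_t>
    void gather_expert_outputs(const scalar_t* output_buf, scalar_t* local_output_buf,
                               const size_t* expert_count, const size_t* all_expert_count,
                               const size_t* expert_ptr, int num_expert, int world_size,
                               int in_feat, ncclComm_t comm, cudaStream_t stream) {
        ncclGroupStart();
        int send_ptr = 0;
        for (int i = 0; i < num_expert; ++i) {
            for (int j = 0; j < world_size; ++j) {          // world_size plays the role of cm->size
                int recv_id = i + j * num_expert;
                if (expert_count[recv_id]) {
                    // receive the rows that rank j computed for this rank's tokens
                    ncclRecv(local_output_buf + expert_ptr[recv_id] * in_feat,
                             expert_count[recv_id] * in_feat * sizeof(scalar_t),
                             ncclChar, j, comm, stream);
                }
                int send_id = i * world_size + j;
                if (all_expert_count[send_id]) {
                    // send the rows this rank computed for rank j's tokens
                    ncclSend(output_buf + send_ptr * in_feat,
                             all_expert_count[send_id] * in_feat * sizeof(scalar_t),
                             ncclChar, j, comm, stream);
                    send_ptr += all_expert_count[send_id];  // assumed; the diff does not show it
                }
            }
        }
        ncclGroupEnd();
    }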
@@ -287,8 +295,10 @@ void moe_cuda_forward_impl(
     cudaFree(input_buf);
     cudaFree(hidden_buf);
     cudaFree(output_buf);
+    if (cm->rank > 1) {
     cudaFree(local_input_buf);
     cudaFree(local_output_buf);
+    }
     cudaFree(d_pos);
     delete [] pos;
     delete [] gate;
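The new guard around the two local buffer frees follows from the aliasing above: in the single-process branch input_buf is the same pointer as local_input_buf and local_output_buf is the same pointer as output_buf, so both allocations are already released by the unconditional cudaFree calls, and freeing the local_ pointers again would hand the same address to cudaFree twice. A small, hypothetical illustration of why that matters:

    // Sketch only: freeing an aliased device pointer twice is invalid.
    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        float *a = nullptr;
        cudaMalloc(&a, 64 * sizeof(float));
        float *b = a;                           // alias, as in the single-process path

        cudaError_t first  = cudaFree(a);       // releases the allocation
        cudaError_t second = cudaFree(b);       // same address again: the runtime typically reports an error
        std::printf("first: %s, second: %s\n",
                    cudaGetErrorString(first), cudaGetErrorString(second));
        return 0;
    }

With the guard in place, the second free simply never happens unless the buffers really are distinct allocations.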