OpenDAS / FastMoE · Commits

Commit 7e949a62
Authored Dec 31, 2020 by Rick Ho
Parent: 4cb75d42

    fix output bug

Showing 1 changed file with 64 additions and 54 deletions.

pytorch/cuda/moe_cuda_kernel.cu  (+64, -54)
@@ -82,8 +82,6 @@ void moe_cuda_forward_impl(
     checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size *
             in_feat));
-    checkCudaErrors(cudaMalloc(&local_output_buf, sizeof(scalar_t) * batch_size *
-            out_feat));
 #ifdef MOE_BREAKDOWN
     timestamp(t_malloc);
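Throughout the file, every cudaMalloc is wrapped in checkCudaErrors. The repository's own helper is not part of this diff; a minimal stand-in with the same behavior (abort with a readable message when a CUDA runtime call fails) might look like the sketch below. It is illustrative only, not the helper actually used by FastMoE.

    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>

    // Minimal stand-in for the checkCudaErrors helper used in the diff above;
    // the real helper lives elsewhere in the repository and may differ.
    #define checkCudaErrors(call)                                           \
        do {                                                                 \
            cudaError_t err_ = (call);                                       \
            if (err_ != cudaSuccess) {                                       \
                fprintf(stderr, "CUDA error %s at %s:%d\n",                  \
                        cudaGetErrorString(err_), __FILE__, __LINE__);       \
                exit(EXIT_FAILURE);                                          \
            }                                                                \
        } while (0)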
@@ -136,12 +134,8 @@ void moe_cuda_forward_impl(
         expert_sz += expert_n[i];
     }
     scalar_t *input_buf, *hidden_buf, *output_buf;
-    checkCudaErrors(cudaMalloc(&input_buf,
-            sizeof(scalar_t) * expert_sz * in_feat));
     checkCudaErrors(cudaMalloc(&hidden_buf,
             sizeof(scalar_t) * expert_sz * hidden_feat));
-    checkCudaErrors(cudaMalloc(&output_buf,
-            sizeof(scalar_t) * expert_sz * out_feat));
 #ifdef MOE_DEBUG
     for (int i = 0; i < tot_expert; ++i) {
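The buffers above are sized by expert_sz, the running sum of per-expert token counts, and the exchange loops in the next hunks index into them through expert_ptr. The usual way to derive both from a count array is an exclusive prefix sum; the helper below is an illustrative reconstruction with assumed types and names (build_expert_offsets is not a function in this file).

    #include <vector>

    // Illustrative only: derive per-expert buffer offsets (an exclusive
    // prefix sum) and the total row count from per-expert token counts.
    // The file's own bookkeeping uses the same idea with its expert_count /
    // expert_ptr / expert_sz variables, but this helper is not from the repo.
    void build_expert_offsets(const std::vector<long>& expert_count,
                              std::vector<long>& expert_ptr,
                              long& expert_sz) {
        expert_ptr.assign(expert_count.size(), 0);
        expert_sz = 0;
        for (size_t e = 0; e < expert_count.size(); ++e) {
            expert_ptr[e] = expert_sz;     // first row belonging to expert e
            expert_sz += expert_count[e];  // total rows packed so far
        }
    }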
@@ -166,32 +160,40 @@ void moe_cuda_forward_impl(
             local_input_buf);
     h->sync(0);
-    ncclGroupStart();
-    int recv_ptr = 0;
-    for (int i = 0; i < num_expert; ++i) {
-        for (int j = 0; j < cm->size; ++j) {
-            int send_id = i + j * num_expert;
-            if (expert_count[send_id]) {
-                ncclSend(local_input_buf + expert_ptr[send_id] * in_feat,
-                        expert_count[send_id] * in_feat * sizeof(scalar_t),
-                        ncclChar,
-                        j,
-                        cm->ncclcomm,
-                        h->getStream(0));
-            }
-            int recv_id = i * cm->size + j;
-            if (all_expert_count[recv_id]) {
-                ncclRecv(input_buf + recv_ptr * in_feat,
-                        all_expert_count[recv_id] * in_feat * sizeof(scalar_t),
-                        ncclChar,
-                        j,
-                        cm->ncclcomm,
-                        h->getStream(0));
-                recv_ptr += all_expert_count[recv_id];
-            }
-        }
-    }
-    ncclGroupEnd();
+    if (cm->rank > 1) {
+        checkCudaErrors(cudaMalloc(&input_buf,
+                sizeof(scalar_t) * expert_sz * in_feat));
+        checkCudaErrors(cudaMalloc(&output_buf,
+                sizeof(scalar_t) * expert_sz * out_feat));
+        ncclGroupStart();
+        int recv_ptr = 0;
+        for (int i = 0; i < num_expert; ++i) {
+            for (int j = 0; j < cm->size; ++j) {
+                int send_id = i + j * num_expert;
+                if (expert_count[send_id]) {
+                    ncclSend(local_input_buf + expert_ptr[send_id] * in_feat,
+                            expert_count[send_id] * in_feat * sizeof(scalar_t),
+                            ncclChar,
+                            j,
+                            cm->ncclcomm,
+                            h->getStream(0));
+                }
+                int recv_id = i * cm->size + j;
+                if (all_expert_count[recv_id]) {
+                    ncclRecv(input_buf + recv_ptr * in_feat,
+                            all_expert_count[recv_id] * in_feat * sizeof(scalar_t),
+                            ncclChar,
+                            j,
+                            cm->ncclcomm,
+                            h->getStream(0));
+                    recv_ptr += all_expert_count[recv_id];
+                }
+            }
+        }
+        ncclGroupEnd();
+    } else {
+        input_buf = local_input_buf;
+    }
 #ifdef MOE_BREAKDOWN
     h->sync();
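This hunk guards the forward NCCL exchange behind cm->rank > 1, moves the input_buf / output_buf allocations inside that branch, and otherwise just aliases input_buf to local_input_buf so the communication is skipped entirely. The send/recv pattern itself is the standard count-driven exchange: each rank posts one ncclSend and one ncclRecv per peer inside a single ncclGroupStart / ncclGroupEnd, skipping peers whose count is zero. A self-contained sketch of that pattern, with assumed buffer layout, byte-count arrays, and function name, is:

    #include <cuda_runtime.h>
    #include <nccl.h>

    // Hedged sketch of the count-driven exchange used in the hunk above:
    // post one send and one receive per peer inside a single NCCL group,
    // skipping peers with a zero count. Buffer layout and names are
    // illustrative assumptions, not the repository's interface.
    void exchange_rows(const char* send_buf, const size_t* send_bytes,
                       const size_t* send_off,
                       char* recv_buf, const size_t* recv_bytes,
                       const size_t* recv_off,
                       int nranks, ncclComm_t comm, cudaStream_t stream) {
        ncclGroupStart();
        for (int peer = 0; peer < nranks; ++peer) {
            if (send_bytes[peer]) {
                ncclSend(send_buf + send_off[peer], send_bytes[peer],
                         ncclChar, peer, comm, stream);
            }
            if (recv_bytes[peer]) {
                ncclRecv(recv_buf + recv_off[peer], recv_bytes[peer],
                         ncclChar, peer, comm, stream);
            }
        }
        ncclGroupEnd();
    }

Posting both directions inside one group lets NCCL schedule the transfers together instead of blocking on point-to-point ordering, which is the same reason the diffed code brackets its loops with ncclGroupStart / ncclGroupEnd.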
@@ -244,32 +246,38 @@ void moe_cuda_forward_impl(
             1e6);
 #endif
-    ncclGroupStart();
-    int send_ptr = 0;
-    for (int i = 0; i < num_expert; ++i) {
-        for (int j = 0; j < cm->size; ++j) {
-            int recv_id = i + j * num_expert;
-            if (expert_count[recv_id]) {
-                ncclRecv(local_input_buf + expert_ptr[recv_id] * in_feat,
-                        expert_count[recv_id] * in_feat * sizeof(scalar_t),
-                        ncclChar,
-                        j,
-                        cm->ncclcomm,
-                        h->getStream(0));
-            }
-            int send_id = i * cm->size + j;
-            if (all_expert_count[send_id]) {
-                ncclSend(input_buf + send_ptr * in_feat,
-                        all_expert_count[send_id] * in_feat * sizeof(scalar_t),
-                        ncclChar,
-                        j,
-                        cm->ncclcomm,
-                        h->getStream(0));
-                send_ptr += all_expert_count[send_id];
-            }
-        }
-    }
-    ncclGroupEnd();
+    if (cm->rank > 1) {
+        checkCudaErrors(cudaMalloc(&local_output_buf,
+                sizeof(scalar_t) * batch_size * out_feat));
+        ncclGroupStart();
+        int send_ptr = 0;
+        for (int i = 0; i < num_expert; ++i) {
+            for (int j = 0; j < cm->size; ++j) {
+                int recv_id = i + j * num_expert;
+                if (expert_count[recv_id]) {
+                    ncclRecv(local_output_buf + expert_ptr[recv_id] * in_feat,
+                            expert_count[recv_id] * in_feat * sizeof(scalar_t),
+                            ncclChar,
+                            j,
+                            cm->ncclcomm,
+                            h->getStream(0));
+                }
+                int send_id = i * cm->size + j;
+                if (all_expert_count[send_id]) {
+                    ncclSend(output_buf + send_ptr * in_feat,
+                            all_expert_count[send_id] * in_feat * sizeof(scalar_t),
+                            ncclChar,
+                            j,
+                            cm->ncclcomm,
+                            h->getStream(0));
+                    send_ptr += all_expert_count[send_id];
+                }
+            }
+        }
+        ncclGroupEnd();
+    } else {
+        local_output_buf = output_buf;
+    }
     batch_gather_kernel<scalar_t>
             <<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos,
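This hunk carries the actual output-bug fix named in the commit message: the return exchange previously reused the input buffers (ncclRecv into local_input_buf, ncclSend from input_buf); it now receives into local_output_buf and sends from output_buf, and it allocates local_output_buf only inside the cm->rank > 1 branch. The results are then put back into original batch order by batch_gather_kernel, whose launch is visible in the trailing context. That kernel is defined elsewhere in moe_cuda_kernel.cu; a hedged sketch of a position-driven row gather with the same launch shape (names, parameter order, and copy direction are assumptions) could look like:

    // Hedged sketch of a position-driven gather like batch_gather_kernel:
    // one block per row, threads striding over the feature dimension,
    // writing each expert-ordered output row back to the slot recorded in
    // the pos index array. Not the repository's actual kernel.
    template <typename scalar_t>
    __global__ void gather_rows_sketch(size_t out_feat, const int* pos,
                                       const scalar_t* expert_ordered,
                                       scalar_t* batch_ordered) {
        const scalar_t* src = expert_ordered + out_feat * blockIdx.x;
        scalar_t* dst = batch_ordered + out_feat * (size_t)pos[blockIdx.x];
        for (size_t k = threadIdx.x; k < out_feat; k += blockDim.x) {
            dst[k] = src[k];
        }
    }

One block per row with a 256-thread stride loop over out_feat matches the <<<batch_size, 256, 0, h->getStream(0)>>> launch configuration shown in the context lines above.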
@@ -287,8 +295,10 @@ void moe_cuda_forward_impl(
     cudaFree(input_buf);
     cudaFree(hidden_buf);
     cudaFree(output_buf);
-    cudaFree(local_input_buf);
-    cudaFree(local_output_buf);
+    if (cm->rank > 1) {
+        cudaFree(local_input_buf);
+        cudaFree(local_output_buf);
+    }
     cudaFree(d_pos);
     delete [] pos;
     delete [] gate;
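The frees for local_input_buf and local_output_buf are now guarded for the same reason the exchanges are: in the else branches of the earlier hunks, input_buf = local_input_buf and local_output_buf = output_buf, so those pointers are already released by the unconditional cudaFree calls above and freeing them again would be a double free. A minimal illustration of that aliasing hazard (names are illustrative, not taken from the file):

    #include <cuda_runtime.h>

    // Illustrative only: when two pointer names may alias one allocation,
    // only the owning name can be freed unconditionally.
    void release_buffers(float* input_buf, float* local_input_buf,
                         bool multi_worker) {
        cudaFree(input_buf);              // always a real allocation
        if (multi_worker) {
            cudaFree(local_input_buf);    // separate allocation only in this path
        }
        // otherwise local_input_buf == input_buf and was freed above already
    }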